Skip to content

[mlir][spirv] Use ODS generated attribute names for op definitions #81552

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 26, 2024

Conversation

tw-ilson
Copy link
Contributor

@tw-ilson tw-ilson commented Feb 13, 2024

Since ODS generates getters functions for SPIRV operations' attribute names, we replace instances of these hardcoded strings in the SPIR-V dialect's op parser/printer with function calls for consistency.

Fixes #77627

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:spirv mlir labels Feb 13, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 13, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: None (tw-ilson)

Changes

Patch is 25.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81552.diff

9 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp (+3-3)
  • (modified) mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp (+6-6)
  • (modified) mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp (+5-3)
  • (modified) mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp (+16-14)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+26-15)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp (+3-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h (+28-27)
  • (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+1-1)
diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index 7e33e91414e0bf..d84133d5933415 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -42,8 +42,9 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
                              << stringifyTypeName<ExpectedElementType>()
                              << " value, found " << elementType;
 
+  StringRef semanticsAttrName = spirv::AtomicAndOp::getSemanticsAttrName(op->getName()).strref();
   auto memorySemantics =
-      op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
+      op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
           .getValue();
   if (failed(verifyMemorySemantics(op, memorySemantics))) {
     return failure();
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 580782043c81b4..650b5448f041b5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -87,7 +87,7 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
         parser.parseRSquare())
       return failure();
 
-    result.addAttribute(kBranchWeightAttrName,
+    result.addAttribute(BranchConditionalOp::getBranchWeightsAttrName(result.name).strref(),
                         builder.getArrayAttr({trueWeight, falseWeight}));
   }
 
@@ -199,11 +199,11 @@ LogicalResult FunctionCallOp::verify() {
 }
 
 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
-  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
+  return (*this)->getAttrOfType<SymbolRefAttr>(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref());
 }
 
 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
-  (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
+  (*this)->setAttr(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref(), callee.get<SymbolRefAttr>());
 }
 
 Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index ac29ab0db586cf..ea82f8442dcfe0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -26,9 +26,9 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
   GroupOperation groupOperation;
   OpAsmParser::UnresolvedOperand valueInfo;
   if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
-                                                kExecutionScopeAttrName) ||
+                                                ControlBarrierOp::getExecutionScopeAttrName(state.name).strref()) ||
       spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
-                                                  kGroupOperationAttrName) ||
+                                                  GroupFAddOp::getGroupOperationAttrName(state.name).strref()) ||
       parser.parseOperand(valueInfo))
     return failure();
 
@@ -61,11 +61,11 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
   printer
       << " \""
       << stringifyScope(
-             groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+             groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
                  .getValue())
       << "\" \""
       << stringifyGroupOperation(
-             groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+             groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
                  .getValue())
       << "\" " << groupOp->getOperand(0);
 
@@ -76,14 +76,14 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
 
 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
   spirv::Scope scope =
-      groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+      groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
           .getValue();
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
     return groupOp->emitOpError(
         "execution scope must be 'Workgroup' or 'Subgroup'");
 
   GroupOperation operation =
-      groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+      groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
           .getValue();
   if (operation == GroupOperation::ClusteredReduce &&
       groupOp->getNumOperands() == 1)
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 00fc2acf7f07d0..3abc4361321d04 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -33,10 +33,11 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
   // ODS enforces that vector 1 and vector 2, and result and the accumulator
   // have the same types.
   Type factorTy = op->getOperand(0).getType();
+  StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto packedVectorFormat =
         llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
-            op->getAttr(kPackedVectorFormatAttrName));
+            op->getAttr(packedVectorFormatAttrName));
     if (!packedVectorFormat)
       return op->emitOpError("requires Packed Vector Format attribute for "
                              "integer vector operands");
@@ -50,7 +51,7 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
                         "integer vector operands to be 32-bits wide",
                         packedVectorFormat.getValue()));
   } else {
-    if (op->hasAttr(kPackedVectorFormatAttrName))
+    if (op->hasAttr(packedVectorFormatAttrName))
       return op->emitOpError(llvm::formatv(
           "with invalid format attribute for vector operands of type '{0}'",
           factorTy));
@@ -99,9 +100,10 @@ getIntegerDotProductCapabilities(Operation *op) {
   SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
 
   Type factorTy = op->getOperand(0).getType();
+  StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
-        op->getAttr(kPackedVectorFormatAttrName));
+        op->getAttr(packedVectorFormatAttrName));
     if (formatAttr.getValue() ==
         spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
       capabilities.push_back(dotProductInput4x8BitPackedCap);
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 0df59e42218b55..eb7900fd6ef8fb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -38,17 +38,19 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
   }
 
   spirv::MemoryAccess memoryAccessAttr;
+  StringRef memoryAccessAttrName = CopyMemoryOp::getMemoryAccessAttrName(state.name).strref();
   if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
-          memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
+          memoryAccessAttr, parser, state, memoryAccessAttrName))
     return failure();
 
   if (spirv::bitEnumContainsAll(memoryAccessAttr,
                                 spirv::MemoryAccess::Aligned)) {
     // Parse integer attribute for alignment.
     Attribute alignmentAttr;
+    StringRef alignmentAttrName = CopyMemoryOp::getSourceAlignmentAttrName(state.name).strref();
     Type i32Type = parser.getBuilder().getIntegerType(32);
     if (parser.parseComma() ||
-        parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
+        parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
                               state.attributes)) {
       return failure();
     }
@@ -72,7 +74,7 @@ static void printSourceMemoryAccessAttribute(
   // Print optional memory access attribute.
   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                               : memoryOp.getMemoryAccess())) {
-    elidedAttrs.push_back(kSourceMemoryAccessAttrName);
+    elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName().strref());
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
@@ -80,7 +82,7 @@ static void printSourceMemoryAccessAttribute(
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.getAlignment())) {
-        elidedAttrs.push_back(kSourceAlignmentAttrName);
+        elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName().strref());
         printer << ", " << *alignment;
       }
     }
@@ -98,7 +100,7 @@ static void printMemoryAccessAttribute(
   // Print optional memory access attribute.
   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                               : memoryOp.getMemoryAccess())) {
-    elidedAttrs.push_back(kMemoryAccessAttrName);
+    elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName().strref());
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
@@ -106,7 +108,7 @@ static void printMemoryAccessAttribute(
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.getAlignment())) {
-        elidedAttrs.push_back(kAlignmentAttrName);
+        elidedAttrs.push_back(memoryOp.getAlignmentAttrName().strref());
         printer << ", " << *alignment;
       }
     }
@@ -136,11 +138,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
   auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
+  auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName().strref());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
-    if (op->getAttr(kAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
       return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
@@ -157,11 +159,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
 
   if (spirv::bitEnumContainsAll(memAccess.getValue(),
                                 spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(kAlignmentAttrName)) {
+    if (!op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
       return memoryOp.emitOpError("missing alignment value");
     }
   } else {
-    if (op->getAttr(kAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
       return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
@@ -180,11 +182,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
   auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
+  auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
-    if (op->getAttr(kSourceAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
       return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
@@ -201,11 +203,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
 
   if (spirv::bitEnumContainsAll(memAccess.getValue(),
                                 spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(kSourceAlignmentAttrName)) {
+    if (!op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
       return memoryOp.emitOpError("missing alignment value");
     }
   } else {
-    if (op->getAttr(kSourceAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
       return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 50035c917137e3..c27df9e30f24e3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -457,12 +457,13 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
                                              OperationState &result) {
   OpAsmParser::UnresolvedOperand compositeInfo;
   Attribute indicesAttr;
+  StringRef indicesAttrName = spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
   Type compositeType;
   SMLoc attrLocation;
 
   if (parser.parseOperand(compositeInfo) ||
       parser.getCurrentLocation(&attrLocation) ||
-      parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+      parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
       parser.parseColonType(compositeType) ||
       parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
     return failure();
@@ -513,11 +514,12 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
   Type objectType, compositeType;
   Attribute indicesAttr;
+  StringRef indicesAttrName = spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
   auto loc = parser.getCurrentLocation();
 
   return failure(
       parser.parseOperandList(operands, 2) ||
-      parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+      parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
       parser.parseColonType(objectType) ||
       parser.parseKeywordType("into", compositeType) ||
       parser.resolveOperands(operands, {objectType, compositeType}, loc,
@@ -559,7 +561,8 @@ void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
   Attribute value;
-  if (parser.parseAttribute(value, kValueAttrName, result.attributes))
+  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name).strref();
+  if (parser.parseAttribute(value, valueAttrName, result.attributes))
     return failure();
 
   Type type = NoneType::get(parser.getContext());
@@ -822,7 +825,7 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
         }))
       return failure();
   }
-  result.addAttribute(kInterfaceAttrName,
+  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
                       parser.getBuilder().getArrayAttr(interfaceVars));
   return success();
 }
@@ -875,7 +878,8 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
     }
     values.push_back(llvm::cast<IntegerAttr>(value).getInt());
   }
-  result.addAttribute(kValuesAttrName,
+  StringRef valuesAttrName = spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
+  result.addAttribute(valuesAttrName,
                       parser.getBuilder().getI32ArrayAttr(values));
   return success();
 }
@@ -1150,16 +1154,17 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
                                            OperationState &result) {
   // Parse variable name.
   StringAttr nameAttr;
+  StringRef initializerAttrName = spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes)) {
     return failure();
   }
 
   // Parse optional initializer
-  if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
+  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
     FlatSymbolRefAttr initSymbol;
     if (parser.parseLParen() ||
-        parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
+        parser.parseAttribute(initSymbol, Type(), initializerAttrName,
                               result.attributes) ||
         parser.parseRParen())
       return failure();
@@ -1170,6 +1175,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
   }
 
   Type type;
+  StringRef typeAttrName = spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
   auto loc = parser.getCurrentLocation();
   if (parser.parseColonType(type)) {
     return failure();
@@ -1177,7 +1183,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
   if (!llvm::isa<spirv::PointerType>(type)) {
     return parser.emitError(loc, "expected spirv.ptr type");
   }
-  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
 
   return success();
 }
@@ -1191,15 +1197,17 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
   printer.printSymbolName(getSymName());
   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
 
+  StringRef initializerAttrName = this->getInitializerAttrName().strref();
   // Print optional initializer
   if (auto initializer = this->getInitializer()) {
-    printer << " " << kInitializerAttrName << '(';
+    printer << " " << initializerAttrName << '(';
     printer.printSymbolName(*initializer);
     printer << ')';
-    elidedAttrs.push_back(kInitializerAttrName);
+    elidedAttrs.push_back(initializerAttrName);
   }
 
-  elidedAttrs.push_back(kTypeAttrName);
+  StringRef typeAttrName = this->getTypeAttrName().strref();
+  elidedAttrs.push_back(typeAttrName);
   spirv::printVariableDecorations(*this, printer, elidedAttrs);
   printer << " : " << getType();
 }
@@ -1220,7 +1228,7 @@ LogicalResult spirv::GlobalVariableOp::verify() {
   }
 
   if (auto init =
-          (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
+          (*this)->getAttrOfType<FlatSymbolRefAttr>(this->getInitializerAttrName().strref())) {
     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
         (*this)->getParentOp(), init.getAttr());
     // TODO: Currently only variable initialization with specialization
@@ -1594,6 +1602,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
   StringAttr nameAttr;
   Attribute valueAttr;
+  StringRef defaultValueAttrName = spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
 
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes))
@@ -1609,7 +1618,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
   }
 
   if (parser.parseEqual() ||
-      parser.parseAttribute(valueAttr, kDefaultValueAttrName,
+      parser.parseAttribute(valueAttr, defaultValueAttrName,
                             result.attributes))
     return failure();
 
@@ -1783,14 +1792,16 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
   if (parser.parseRParen())
     return failure();
 
-  result.addAttribute(kCompositeSpecConstituentsName,
+  StringRef compositeSpecConstituentsName = spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name).strref();
+  result.addAttribute(compositeSpecConstituentsName,
                       parser.getBuilder().getArrayAttr(constituents));
 
   Type type;
   if (parser.parseColonType(type))
     return failure();
 
-  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+  StringRef typeAttrName = spirv::SpecConstantCompositeOp::getTypeAttrName(result.name).strref();
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
 
   return success();
 }
diff --git a/mlir/...
[truncated]

@tw-ilson tw-ilson changed the title replaces attribute names in SPIRV dialect parsing code replaces hardcoded attribute names in SPIRV dialect parsing code Feb 13, 2024
Copy link

github-actions bot commented Feb 13, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@@ -42,8 +42,9 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
<< stringifyTypeName<ExpectedElementType>()
<< " value, found " << elementType;

StringRef semanticsAttrName = spirv::AtomicAndOp::getSemanticsAttrName(op->getName()).strref();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite get why you use StringRef here?
Can we stick to StringAttr instead? This will make it use pointer comparison instead of string comparisons.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, will do

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

@tw-ilson tw-ilson marked this pull request as draft February 14, 2024 19:26
@tw-ilson
Copy link
Contributor Author

added templated functions instead of removing assert from codegen

@tw-ilson
Copy link
Contributor Author

#77627

@tw-ilson tw-ilson marked this pull request as ready for review February 15, 2024 02:10
@tw-ilson tw-ilson requested a review from joker-eph February 15, 2024 02:11
@@ -875,7 +882,9 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
}
values.push_back(llvm::cast<IntegerAttr>(value).getInt());
}
result.addAttribute(kValuesAttrName,
StringRef valuesAttrName =
spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StringRef still?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops.
Otherwise good though?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only superficially looked, this will be more for @antiagainst to approve I think.

I would mentioned that this PR would benefit from a description (the title says the "what" is done, the description should say the "why" and provide context).
(See https://mlir.llvm.org/getting_started/Contributing/#commit-messages )

@antiagainst antiagainst changed the title replaces hardcoded attribute names in SPIRV dialect parsing code [mlir][spirv] Use ODS generated attribute names for op definitions Feb 26, 2024
Copy link
Member

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks for helping to fix this! I just have a few nits; but given it has been some time waiting for reviews (sorry!) so I just went ahead fixed them with cb22157.

@antiagainst antiagainst merged commit 1d5e3b2 into llvm:main Feb 26, 2024
Copy link

@tw-ilson Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested
by our build bots. If there is a problem with a build, you may recieve a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as
the builds can include changes from many authors. It is not uncommon for your
change to be included in a build that fails due to someone else's changes, or
infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself.
This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:spirv mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][spirv] Replace hardcoded attribute strings with op methods
4 participants