Skip to content

Commit ead0a97

Browse files
SahilPatidarSahilPatidarantiagainst
authored
[mlir][spirv] Replace hardcoded strings with op methods (#81443)
Progress towards #77627 --------- Co-authored-by: SahilPatidar <[email protected]> Co-authored-by: Lei Zhang <[email protected]>
1 parent 339baae commit ead0a97

File tree

4 files changed

+23
-13
lines changed

4 files changed

+23
-13
lines changed

mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,8 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
489489
auto attrValue = words[wordIndex++];
490490
auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
491491
static_cast<spirv::MemoryAccess>(attrValue));
492-
attributes.push_back(opBuilder.getNamedAttr("memory_access", attr));
492+
attributes.push_back(
493+
opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
493494
isAlignedAttr = (attrValue == 2);
494495
}
495496

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,11 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216216
return emitError(unknownLoc, "OpMemoryModel must have two operands");
217217

218218
(*module)->setAttr(
219-
"addressing_model",
219+
module->getAddressingModelAttrName(),
220220
opBuilder.getAttr<spirv::AddressingModelAttr>(
221221
static_cast<spirv::AddressingModel>(operands.front())));
222-
(*module)->setAttr("memory_model",
222+
223+
(*module)->setAttr(module->getMemoryModelAttrName(),
223224
opBuilder.getAttr<spirv::MemoryModelAttr>(
224225
static_cast<spirv::MemoryModel>(operands.back())));
225226

mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp

+12-8
Original file line numberDiff line numberDiff line change
@@ -709,33 +709,37 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
709709
operands.push_back(id);
710710
}
711711

712-
if (auto attr = op->getAttr("memory_access")) {
712+
StringAttr memoryAccess = op.getMemoryAccessAttrName();
713+
if (auto attr = op->getAttr(memoryAccess)) {
713714
operands.push_back(
714715
static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
715716
}
716717

717-
elidedAttrs.push_back("memory_access");
718+
elidedAttrs.push_back(memoryAccess.strref());
718719

719-
if (auto attr = op->getAttr("alignment")) {
720+
StringAttr alignment = op.getAlignmentAttrName();
721+
if (auto attr = op->getAttr(alignment)) {
720722
operands.push_back(static_cast<uint32_t>(
721723
cast<IntegerAttr>(attr).getValue().getZExtValue()));
722724
}
723725

724-
elidedAttrs.push_back("alignment");
726+
elidedAttrs.push_back(alignment.strref());
725727

726-
if (auto attr = op->getAttr("source_memory_access")) {
728+
StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
729+
if (auto attr = op->getAttr(sourceMemoryAccess)) {
727730
operands.push_back(
728731
static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
729732
}
730733

731-
elidedAttrs.push_back("source_memory_access");
734+
elidedAttrs.push_back(sourceMemoryAccess.strref());
732735

733-
if (auto attr = op->getAttr("source_alignment")) {
736+
StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
737+
if (auto attr = op->getAttr(sourceAlignment)) {
734738
operands.push_back(static_cast<uint32_t>(
735739
cast<IntegerAttr>(attr).getValue().getZExtValue()));
736740
}
737741

738-
elidedAttrs.push_back("source_alignment");
742+
elidedAttrs.push_back(sourceAlignment.strref());
739743
if (failed(emitDebugLine(functionBody, op.getLoc())))
740744
return failure();
741745
encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,14 @@ void Serializer::processExtension() {
197197
}
198198

199199
void Serializer::processMemoryModel() {
200+
StringAttr memoryModelName = module.getMemoryModelAttrName();
200201
auto mm = static_cast<uint32_t>(
201-
module->getAttrOfType<spirv::MemoryModelAttr>("memory_model").getValue());
202+
module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
203+
.getValue());
204+
205+
StringAttr addressingModelName = module.getAddressingModelAttrName();
202206
auto am = static_cast<uint32_t>(
203-
module->getAttrOfType<spirv::AddressingModelAttr>("addressing_model")
207+
module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
204208
.getValue());
205209

206210
encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});

0 commit comments

Comments
 (0)