diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index a678124bf4832..5b2903824c9e7 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -489,7 +489,8 @@ Deserializer::processOp(ArrayRef words) { auto attrValue = words[wordIndex++]; auto attr = opBuilder.getAttr( static_cast(attrValue)); - attributes.push_back(opBuilder.getNamedAttr("memory_access", attr)); + attributes.push_back( + opBuilder.getNamedAttr(attributeName(), attr)); isAlignedAttr = (attrValue == 2); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 02d03b3a0faee..83ef01b4e3a46 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -216,10 +216,11 @@ spirv::Deserializer::processMemoryModel(ArrayRef operands) { return emitError(unknownLoc, "OpMemoryModel must have two operands"); (*module)->setAttr( - "addressing_model", + module->getAddressingModelAttrName(), opBuilder.getAttr( static_cast(operands.front()))); - (*module)->setAttr("memory_model", + + (*module)->setAttr(module->getMemoryModelAttrName(), opBuilder.getAttr( static_cast(operands.back()))); diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp index e68ed5efaca74..c283e64fa185a 100644 --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -709,33 +709,37 @@ Serializer::processOp(spirv::CopyMemoryOp op) { operands.push_back(id); } - if (auto attr = op->getAttr("memory_access")) { + StringAttr memoryAccess = op.getMemoryAccessAttrName(); + if (auto attr = op->getAttr(memoryAccess)) { operands.push_back( static_cast(cast(attr).getValue())); } - elidedAttrs.push_back("memory_access"); + elidedAttrs.push_back(memoryAccess.strref()); - if (auto attr = op->getAttr("alignment")) { + StringAttr alignment = op.getAlignmentAttrName(); + if (auto attr = op->getAttr(alignment)) { operands.push_back(static_cast( cast(attr).getValue().getZExtValue())); } - elidedAttrs.push_back("alignment"); + elidedAttrs.push_back(alignment.strref()); - if (auto attr = op->getAttr("source_memory_access")) { + StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName(); + if (auto attr = op->getAttr(sourceMemoryAccess)) { operands.push_back( static_cast(cast(attr).getValue())); } - elidedAttrs.push_back("source_memory_access"); + elidedAttrs.push_back(sourceMemoryAccess.strref()); - if (auto attr = op->getAttr("source_alignment")) { + StringAttr sourceAlignment = op.getSourceAlignmentAttrName(); + if (auto attr = op->getAttr(sourceAlignment)) { operands.push_back(static_cast( cast(attr).getValue().getZExtValue())); } - elidedAttrs.push_back("source_alignment"); + elidedAttrs.push_back(sourceAlignment.strref()); if (failed(emitDebugLine(functionBody, op.getLoc()))) return failure(); encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 40337e007bbf7..4a4e878d8af91 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -197,10 +197,14 @@ void Serializer::processExtension() { } void Serializer::processMemoryModel() { + StringAttr memoryModelName = module.getMemoryModelAttrName(); auto mm = static_cast( - module->getAttrOfType("memory_model").getValue()); + module->getAttrOfType(memoryModelName) + .getValue()); + + StringAttr addressingModelName = module.getAddressingModelAttrName(); auto am = static_cast( - module->getAttrOfType("addressing_model") + module->getAttrOfType(addressingModelName) .getValue()); encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});