-
Notifications
You must be signed in to change notification settings - Fork 15k
[mlir][spirv] Add basic support for SPV_EXT_replicated_composites #147067
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Add basic support for SPV_EXT_replicated_composites #147067
Conversation
This patch introduces two new ops to the SPIR-V dialect: - `spirv.EXT.ConstantCompositeReplicate` - `spirv.EXT.SpecConstantCompositeReplicate` These ops represent composite constants and specialization constants, respectively, constructed by replicating a single splat constant across all elements. They correspond to `SPV_EXT_replicated_composites` extension instructions: - `OpConstantCompositeReplicatedEXT` - `OpSpecConstantCompositeReplicatedEXT` No transformation to these new ops has been introduced in this patch. This approach is chosen as per the discussions on RFC https://discourse.llvm.org/t/rfc-basic-support-for-spv-ext-replicated-composites-in-mlir-spir-v-compile-time-constant-lowering-only/86987 Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir-spirv Author: Mohammadreza Ameri Mahabadian (mahabadm) ChangesThis patch introduces two new ops to the SPIR-V dialect:
These ops represent composite constants and specialization constants, respectively, constructed by replicating a single splat constant across all elements. They correspond to
No transformation to these new ops has been introduced in this patch. This approach is chosen as per the discussions on RFC https://discourse.llvm.org/t/rfc-basic-support-for-spv-ext-replicated-composites-in-mlir-spir-v-compile-time-constant-lowering-only/86987 Patch is 36.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147067.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d874817e6888d..6c24dbc613c82 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -359,6 +359,7 @@ def SPV_EXT_shader_atomic_float_min_max : I32EnumAttrCase<"SPV_EXT_shader_atomi
def SPV_EXT_shader_image_int64 : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
def SPV_EXT_shader_atomic_float16_add : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
def SPV_EXT_mesh_shader : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
+def SPV_EXT_replicated_composites : I32EnumAttrCase<"SPV_EXT_replicated_composites", 1013>;
def SPV_AMD_gpu_shader_half_float_fetch : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
def SPV_AMD_shader_ballot : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -446,7 +447,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
- SPV_EXT_mesh_shader,
+ SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
SPV_ARM_tensors,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
@@ -849,6 +850,12 @@ def SPIRV_C_CooperativeMatrixKHR : I32EnumAttrCase<"Coope
MinVersion<SPIRV_V_1_6>
];
}
+def SPIRV_C_ReplicatedCompositesEXT : I32EnumAttrCase<"ReplicatedCompositesEXT", 6024> {
+ list<Availability> availability = [
+ Extension<[SPV_EXT_replicated_composites]>,
+ MinVersion<SPIRV_V_1_0>
+ ];
+}
def SPIRV_C_BitInstructions : I32EnumAttrCase<"BitInstructions", 6025> {
list<Availability> availability = [
Extension<[SPV_KHR_bit_instructions]>
@@ -1500,7 +1507,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL,
SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll,
SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR,
- SPIRV_C_CooperativeMatrixKHR,
+ SPIRV_C_CooperativeMatrixKHR, SPIRV_C_ReplicatedCompositesEXT,
SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL,
SPIRV_C_AtomicFloat16AddEXT, SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL,
@@ -4564,6 +4571,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMa
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
+def SPIRV_OC_OpConstantCompositeReplicateEXT : I32EnumAttrCase<"OpConstantCompositeReplicateEXT", 4461>;
+def SPIRV_OC_OpSpecConstantCompositeReplicateEXT : I32EnumAttrCase<"OpSpecConstantCompositeReplicateEXT", 4462>;
def SPIRV_OC_OpEmitMeshTasksEXT : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
def SPIRV_OC_OpSetMeshOutputsEXT : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
def SPIRV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
@@ -4672,6 +4681,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR,
SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+ SPIRV_OC_OpConstantCompositeReplicateEXT,
+ SPIRV_OC_OpSpecConstantCompositeReplicateEXT,
SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index c5a85f881b35e..0a5b01fe9e8d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -135,6 +135,52 @@ def SPIRV_ConstantOp : SPIRV_Op<"Constant",
let autogenSerialization = 0;
}
+
+// -----
+
+def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantCompositeReplicate", [Pure]> {
+ let summary = [{
+ Declare a new replicated composite constant op.
+ }];
+
+ let description = [{
+ This op declares a `spiv.EXT.ConstantCompositeReplicate` which represents a
+ splat composite constant i.e. all element of composite constant have the
+ same value. This op will be serialized to SPIR-V `OpConstantCompositeReplicateEXT`.
+ The splat value must come from a non-specialization constant instruction."
+
+ #### Example:
+
+ ```mlir
+ %0 = spirv.Constant 1 : i32
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
+
+ %2 = spirv.Constant dense<[1, 2]> : vector<2xi32>
+ %3 = spirv.EXT.ConstantCompositeReplicate %2 : !spirv.array<2 x vector<2xi32>>
+
+ %5 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+ %6 = spirv.EXT.ConstantCompositeReplicate %5 : !spirv.array<2 x vector<2xi32>>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_EXT_replicated_composites]>,
+ Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+ ];
+
+ let arguments = (ins
+ SPIRV_Type:$constant
+ );
+
+ let results = (outs
+ SPIRV_Composite:$replicated_constant
+ );
+
+ let autogenSerialization = 0;
+}
+
// -----
def SPIRV_EntryPointOp : SPIRV_Op<"EntryPoint", [InModuleScope]> {
@@ -689,6 +735,46 @@ def SPIRV_SpecConstantCompositeOp : SPIRV_Op<"SpecConstantComposite", [
// -----
+def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantCompositeReplicate", [InModuleScope, Symbol]> {
+ let summary = "Declare a new replicated composite specialization constant op.";
+
+ let description = [{
+ This op declares a `spirv.EXT.SpecConstantCompositeReplicate` which represents
+ a splat specialization composite constant i.e. all element of specialization
+ composite constant have the same value. This op will be serialized to SPIR-V
+ `OpSpecConstantCompositeReplicateEXT`. The splat value must come from a
+ symbol reference of specialization constant instruction.
+
+ #### Example:
+
+ ```mlir
+ spirv.SpecConstant @sc_i32_1 = 1 : i32
+ spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+ spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_EXT_replicated_composites]>,
+ Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+ ];
+
+ let arguments = (ins
+ TypeAttr:$type,
+ StrAttr:$sym_name,
+ SymbolRefAttr:$constituent
+ );
+
+ let results = (outs);
+
+ let autogenSerialization = 0;
+
+}
+
+// -----
+
def SPIRV_SpecConstantOperationOp : SPIRV_Op<"SpecConstantOperation", [
Pure, InFunctionScope,
SingleBlockImplicitTerminator<"YieldOp">]> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index eb2974d62fdd1..c42b2d45d53a9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -765,6 +765,67 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
setNameFn(getResult(), specialName.str());
}
+//===----------------------------------------------------------------------===//
+// spirv.EXTConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+ParseResult
+spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::UnresolvedOperand constOperand;
+ Type compositeType;
+ if (parser.parseOperand(constOperand) ||
+ parser.parseColonType(compositeType)) {
+ return failure();
+ }
+
+ if (llvm::isa<TensorType>(compositeType)) {
+ if (parser.parseColonType(compositeType))
+ return failure();
+ }
+
+ auto constType = cast<spirv::CompositeType>(compositeType).getElementType(0);
+ while (auto type = llvm::dyn_cast<spirv::ArrayType>(constType)) {
+ constType = type.getElementType();
+ }
+
+ if (parser.resolveOperand(constOperand, constType, result.operands))
+ return failure();
+
+ return parser.addTypeToList(compositeType, result.types);
+}
+
+void spirv::EXTConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+ printer << ' ' << getConstant() << " : " << getType();
+}
+
+LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
+ auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+ if (!compositeType)
+ return emitError("result type must be a composite type, but provided ")
+ << getType();
+
+ auto constantDefiningOp = getConstant().getDefiningOp();
+ if (!constantDefiningOp)
+ return this->emitOpError("op defining the splat constant not found");
+
+ auto constantOp = dyn_cast_or_null<spirv::ConstantOp>(constantDefiningOp);
+ auto constantCompositeReplicateOp =
+ dyn_cast_or_null<spirv::EXTConstantCompositeReplicateOp>(
+ constantDefiningOp);
+
+ if (!constantOp && !constantCompositeReplicateOp)
+ return this->emitOpError(
+ "op defining the splat constant is not a spirv.Constant or a "
+ "spirv.EXT.ConstantCompositeReplicate");
+
+ if (constantOp)
+ return verifyConstantType(constantOp, constantOp.getValueAttr(),
+ constantOp.getType());
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.ControlBarrierOp
//===----------------------------------------------------------------------===//
@@ -1866,6 +1927,64 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.EXTSpecConstantCompositeReplicateOp
+//===----------------------------------------------------------------------===//
+
+ParseResult
+spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+
+ StringAttr compositeName;
+ const char *attrName = "spec_const";
+ FlatSymbolRefAttr specConstRef;
+ NamedAttrList attrs;
+ Type type;
+
+ if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
+ result.attributes) ||
+ parser.parseLParen() ||
+ parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
+ parser.parseRParen() || parser.parseColonType(type))
+ return failure();
+
+ StringAttr compositeSpecConstituentName =
+ spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
+ result.name);
+ result.addAttribute(compositeSpecConstituentName, specConstRef);
+
+ StringAttr typeAttrName =
+ spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
+ result.addAttribute(typeAttrName, TypeAttr::get(type));
+
+ return success();
+}
+
+void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+ printer << " ";
+ printer.printSymbolName(getSymName());
+ printer << " (" << this->getConstituent() << ") : " << getType();
+}
+
+LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
+ auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+ if (!compositeType)
+ return emitError("result type must be a composite type, but provided ")
+ << getType();
+
+ auto constituentSpecConstOp =
+ dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
+ (*this)->getParentOp(), this->getConstituent()));
+
+ auto constituentType = constituentSpecConstOp.getDefaultValue().getType();
+ auto compositeElemType = compositeType.getElementType(0);
+ if (constituentType != compositeElemType)
+ return emitError("constituent has incorrect type: expected ")
+ << compositeElemType << ", but provided " << constituentType;
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.SpecConstantOperation
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 55d6a380d0bff..5f52308b4be35 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -45,6 +45,12 @@ Value spirv::Deserializer::getValue(uint32_t id) {
return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
constInfo->first);
}
+ if (auto constCompositeReplicateInfo = getConstantCompositeReplicate(id)) {
+ auto constantId = constCompositeReplicateInfo->first;
+ auto element = getValue(constantId);
+ return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
+ unknownLoc, constCompositeReplicateInfo->second, element);
+ }
if (auto varOp = getGlobalVariable(id)) {
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
@@ -56,10 +62,17 @@ Value spirv::Deserializer::getValue(uint32_t id) {
SymbolRefAttr::get(constOp.getOperation()));
return referenceOfOp.getReference();
}
- if (auto constCompositeOp = getSpecConstantComposite(id)) {
+ if (auto specConstCompositeOp = getSpecConstantComposite(id)) {
+ auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+ unknownLoc, specConstCompositeOp.getType(),
+ SymbolRefAttr::get(specConstCompositeOp.getOperation()));
+ return referenceOfOp.getReference();
+ }
+ if (auto specConstCompositeReplicateOp =
+ getSpecConstantCompositeReplicate(id)) {
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
- unknownLoc, constCompositeOp.getType(),
- SymbolRefAttr::get(constCompositeOp.getOperation()));
+ unknownLoc, specConstCompositeReplicateOp.getType(),
+ SymbolRefAttr::get(specConstCompositeReplicateOp.getOperation()));
return referenceOfOp.getReference();
}
if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
@@ -175,8 +188,12 @@ LogicalResult spirv::Deserializer::processInstruction(
return processConstant(operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantComposite:
return processConstantComposite(operands);
+ case spirv::Opcode::OpConstantCompositeReplicateEXT:
+ return processConstantCompositeReplicateEXT(operands);
case spirv::Opcode::OpSpecConstantComposite:
return processSpecConstantComposite(operands);
+ case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
+ return processSpecConstantCompositeReplicateEXT(operands);
case spirv::Opcode::OpSpecConstantOp:
return processSpecConstantOperation(operands);
case spirv::Opcode::OpConstantTrue:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b1abd8b3dffe9..2163ccff93c83 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -678,6 +678,14 @@ spirv::Deserializer::getConstant(uint32_t id) {
return constIt->getSecond();
}
+std::optional<std::pair<uint32_t, Type>>
+spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
+ auto constIt = constantCompositeReplicateMap.find(id);
+ if (constIt == constantCompositeReplicateMap.end())
+ return std::nullopt;
+ return constIt->getSecond();
+}
+
std::optional<spirv::SpecConstOperationMaterializationInfo>
spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
auto constIt = specConstOperationMap.find(id);
@@ -1554,15 +1562,58 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
+ ArrayRef<uint32_t> operands) {
+
+ if (operands.size() != 3) {
+ return emitError(
+ unknownLoc,
+ "OpConstantCompositeReplicateEXT must have type <id> and result <id> "
+ "and only one parameter which is <id> of splat constant");
+ }
+
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ auto compositeType = dyn_cast<CompositeType>(resultType);
+ if (!compositeType) {
+ return emitError(unknownLoc,
+ "result type from <id> is not a composite type")
+ << operands[0];
+ }
+
+ auto resultID = operands[1];
+ auto constantID = operands[2];
+
+ auto constantInfo = getConstant(constantID);
+ auto replicatedConstantCompositeInfo =
+ getConstantCompositeReplicate(constantID);
+ if (!constantInfo && !replicatedConstantCompositeInfo) {
+ return emitError(unknownLoc,
+ "OpConstantCompositeReplicateEXT operand <id> ")
+ << constantID
+ << " must come from a normal constant or a "
+ "OpConstantCompositeReplicateEXT";
+ }
+
+ constantCompositeReplicateMap.try_emplace(resultID, constantID, resultType);
+
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
if (operands.size() < 2) {
- return emitError(unknownLoc,
- "OpConstantComposite must have type <id> and result <id>");
+ return emitError(
+ unknownLoc,
+ "OpSpecConstantComposite must have type <id> and result <id>");
}
if (operands.size() < 3) {
return emitError(unknownLoc,
- "OpConstantComposite must have at least 1 parameter");
+ "OpSpecConstantComposite must have at least 1 parameter");
}
Type resultType = getType(operands[0]);
@@ -1589,6 +1640,42 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
+ ArrayRef<uint32_t> operands) {
+
+ if (operands.size() != 3) {
+ return emitError(unknownLoc,
+ "OpSpecConstantCompositeReplicateEXT must have "
+ "type <id> and result <id> and only one parameter which "
+ "is <id> of splat constant");
+ }
+
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ auto compositeType = dyn_cast<CompositeType>(resultType);
+ if (!compositeType) {
+ return emitError(unknownLoc,
+ "result type from <id> is not a composite type")
+ << operands[0];
+ }
+
+ auto resultID = operands[1];
+
+ auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
+ auto constituentSpecConstantOp = getSpecConstant(operands[2]);
+ auto op = opBuilder.create<spirv::EXTSpecConstantCompositeReplicateOp>(
+ unknownLoc, TypeAttr::get(resultType), symName,
+ SymbolRefAttr::get(constituentSpecConstantOp));
+
+ specConstCompositeReplicateMap[resultID] = op;
+
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
if (operands.size() < 3)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 1bc9e4a3c75d8..1fdecc3e6fe0d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -190,6 +190,12 @@ class Deseria...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Mohammadreza Ameri Mahabadian (mahabadm) ChangesThis patch introduces two new ops to the SPIR-V dialect:
These ops represent composite constants and specialization constants, respectively, constructed by replicating a single splat constant across all elements. They correspond to
No transformation to these new ops has been introduced in this patch. This approach is chosen as per the discussions on RFC https://discourse.llvm.org/t/rfc-basic-support-for-spv-ext-replicated-composites-in-mlir-spir-v-compile-time-constant-lowering-only/86987 Patch is 36.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147067.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d874817e6888d..6c24dbc613c82 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -359,6 +359,7 @@ def SPV_EXT_shader_atomic_float_min_max : I32EnumAttrCase<"SPV_EXT_shader_atomi
def SPV_EXT_shader_image_int64 : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
def SPV_EXT_shader_atomic_float16_add : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
def SPV_EXT_mesh_shader : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
+def SPV_EXT_replicated_composites : I32EnumAttrCase<"SPV_EXT_replicated_composites", 1013>;
def SPV_AMD_gpu_shader_half_float_fetch : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
def SPV_AMD_shader_ballot : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -446,7 +447,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
- SPV_EXT_mesh_shader,
+ SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
SPV_ARM_tensors,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
@@ -849,6 +850,12 @@ def SPIRV_C_CooperativeMatrixKHR : I32EnumAttrCase<"Coope
MinVersion<SPIRV_V_1_6>
];
}
+def SPIRV_C_ReplicatedCompositesEXT : I32EnumAttrCase<"ReplicatedCompositesEXT", 6024> {
+ list<Availability> availability = [
+ Extension<[SPV_EXT_replicated_composites]>,
+ MinVersion<SPIRV_V_1_0>
+ ];
+}
def SPIRV_C_BitInstructions : I32EnumAttrCase<"BitInstructions", 6025> {
list<Availability> availability = [
Extension<[SPV_KHR_bit_instructions]>
@@ -1500,7 +1507,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL,
SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll,
SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR,
- SPIRV_C_CooperativeMatrixKHR,
+ SPIRV_C_CooperativeMatrixKHR, SPIRV_C_ReplicatedCompositesEXT,
SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL,
SPIRV_C_AtomicFloat16AddEXT, SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL,
@@ -4564,6 +4571,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMa
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
+def SPIRV_OC_OpConstantCompositeReplicateEXT : I32EnumAttrCase<"OpConstantCompositeReplicateEXT", 4461>;
+def SPIRV_OC_OpSpecConstantCompositeReplicateEXT : I32EnumAttrCase<"OpSpecConstantCompositeReplicateEXT", 4462>;
def SPIRV_OC_OpEmitMeshTasksEXT : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
def SPIRV_OC_OpSetMeshOutputsEXT : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
def SPIRV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
@@ -4672,6 +4681,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR,
SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+ SPIRV_OC_OpConstantCompositeReplicateEXT,
+ SPIRV_OC_OpSpecConstantCompositeReplicateEXT,
SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index c5a85f881b35e..0a5b01fe9e8d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -135,6 +135,52 @@ def SPIRV_ConstantOp : SPIRV_Op<"Constant",
let autogenSerialization = 0;
}
+
+// -----
+
+def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantCompositeReplicate", [Pure]> {
+ let summary = [{
+ Declare a new replicated composite constant op.
+ }];
+
+ let description = [{
+ This op declares a `spiv.EXT.ConstantCompositeReplicate` which represents a
+ splat composite constant i.e. all element of composite constant have the
+ same value. This op will be serialized to SPIR-V `OpConstantCompositeReplicateEXT`.
+ The splat value must come from a non-specialization constant instruction."
+
+ #### Example:
+
+ ```mlir
+ %0 = spirv.Constant 1 : i32
+ %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
+
+ %2 = spirv.Constant dense<[1, 2]> : vector<2xi32>
+ %3 = spirv.EXT.ConstantCompositeReplicate %2 : !spirv.array<2 x vector<2xi32>>
+
+ %5 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+ %6 = spirv.EXT.ConstantCompositeReplicate %5 : !spirv.array<2 x vector<2xi32>>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_EXT_replicated_composites]>,
+ Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+ ];
+
+ let arguments = (ins
+ SPIRV_Type:$constant
+ );
+
+ let results = (outs
+ SPIRV_Composite:$replicated_constant
+ );
+
+ let autogenSerialization = 0;
+}
+
// -----
def SPIRV_EntryPointOp : SPIRV_Op<"EntryPoint", [InModuleScope]> {
@@ -689,6 +735,46 @@ def SPIRV_SpecConstantCompositeOp : SPIRV_Op<"SpecConstantComposite", [
// -----
+def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantCompositeReplicate", [InModuleScope, Symbol]> {
+ let summary = "Declare a new replicated composite specialization constant op.";
+
+ let description = [{
+ This op declares a `spirv.EXT.SpecConstantCompositeReplicate` which represents
+ a splat specialization composite constant i.e. all element of specialization
+ composite constant have the same value. This op will be serialized to SPIR-V
+ `OpSpecConstantCompositeReplicateEXT`. The splat value must come from a
+ symbol reference of specialization constant instruction.
+
+ #### Example:
+
+ ```mlir
+ spirv.SpecConstant @sc_i32_1 = 1 : i32
+ spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+ spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_EXT_replicated_composites]>,
+ Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+ ];
+
+ let arguments = (ins
+ TypeAttr:$type,
+ StrAttr:$sym_name,
+ SymbolRefAttr:$constituent
+ );
+
+ let results = (outs);
+
+ let autogenSerialization = 0;
+
+}
+
+// -----
+
def SPIRV_SpecConstantOperationOp : SPIRV_Op<"SpecConstantOperation", [
Pure, InFunctionScope,
SingleBlockImplicitTerminator<"YieldOp">]> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index eb2974d62fdd1..c42b2d45d53a9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -765,6 +765,67 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
setNameFn(getResult(), specialName.str());
}
+//===----------------------------------------------------------------------===//
+// spirv.EXTConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+ParseResult
+spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::UnresolvedOperand constOperand;
+ Type compositeType;
+ if (parser.parseOperand(constOperand) ||
+ parser.parseColonType(compositeType)) {
+ return failure();
+ }
+
+ if (llvm::isa<TensorType>(compositeType)) {
+ if (parser.parseColonType(compositeType))
+ return failure();
+ }
+
+ auto constType = cast<spirv::CompositeType>(compositeType).getElementType(0);
+ while (auto type = llvm::dyn_cast<spirv::ArrayType>(constType)) {
+ constType = type.getElementType();
+ }
+
+ if (parser.resolveOperand(constOperand, constType, result.operands))
+ return failure();
+
+ return parser.addTypeToList(compositeType, result.types);
+}
+
+void spirv::EXTConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+ printer << ' ' << getConstant() << " : " << getType();
+}
+
+LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
+ auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+ if (!compositeType)
+ return emitError("result type must be a composite type, but provided ")
+ << getType();
+
+ auto constantDefiningOp = getConstant().getDefiningOp();
+ if (!constantDefiningOp)
+ return this->emitOpError("op defining the splat constant not found");
+
+ auto constantOp = dyn_cast_or_null<spirv::ConstantOp>(constantDefiningOp);
+ auto constantCompositeReplicateOp =
+ dyn_cast_or_null<spirv::EXTConstantCompositeReplicateOp>(
+ constantDefiningOp);
+
+ if (!constantOp && !constantCompositeReplicateOp)
+ return this->emitOpError(
+ "op defining the splat constant is not a spirv.Constant or a "
+ "spirv.EXT.ConstantCompositeReplicate");
+
+ if (constantOp)
+ return verifyConstantType(constantOp, constantOp.getValueAttr(),
+ constantOp.getType());
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.ControlBarrierOp
//===----------------------------------------------------------------------===//
@@ -1866,6 +1927,64 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.EXTSpecConstantCompositeReplicateOp
+//===----------------------------------------------------------------------===//
+
+ParseResult
+spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+
+ StringAttr compositeName;
+ const char *attrName = "spec_const";
+ FlatSymbolRefAttr specConstRef;
+ NamedAttrList attrs;
+ Type type;
+
+ if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
+ result.attributes) ||
+ parser.parseLParen() ||
+ parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
+ parser.parseRParen() || parser.parseColonType(type))
+ return failure();
+
+ StringAttr compositeSpecConstituentName =
+ spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
+ result.name);
+ result.addAttribute(compositeSpecConstituentName, specConstRef);
+
+ StringAttr typeAttrName =
+ spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
+ result.addAttribute(typeAttrName, TypeAttr::get(type));
+
+ return success();
+}
+
+void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+ printer << " ";
+ printer.printSymbolName(getSymName());
+ printer << " (" << this->getConstituent() << ") : " << getType();
+}
+
+LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
+ auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+ if (!compositeType)
+ return emitError("result type must be a composite type, but provided ")
+ << getType();
+
+ auto constituentSpecConstOp =
+ dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
+ (*this)->getParentOp(), this->getConstituent()));
+
+ auto constituentType = constituentSpecConstOp.getDefaultValue().getType();
+ auto compositeElemType = compositeType.getElementType(0);
+ if (constituentType != compositeElemType)
+ return emitError("constituent has incorrect type: expected ")
+ << compositeElemType << ", but provided " << constituentType;
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.SpecConstantOperation
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 55d6a380d0bff..5f52308b4be35 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -45,6 +45,12 @@ Value spirv::Deserializer::getValue(uint32_t id) {
return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
constInfo->first);
}
+ if (auto constCompositeReplicateInfo = getConstantCompositeReplicate(id)) {
+ auto constantId = constCompositeReplicateInfo->first;
+ auto element = getValue(constantId);
+ return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
+ unknownLoc, constCompositeReplicateInfo->second, element);
+ }
if (auto varOp = getGlobalVariable(id)) {
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
@@ -56,10 +62,17 @@ Value spirv::Deserializer::getValue(uint32_t id) {
SymbolRefAttr::get(constOp.getOperation()));
return referenceOfOp.getReference();
}
- if (auto constCompositeOp = getSpecConstantComposite(id)) {
+ if (auto specConstCompositeOp = getSpecConstantComposite(id)) {
+ auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+ unknownLoc, specConstCompositeOp.getType(),
+ SymbolRefAttr::get(specConstCompositeOp.getOperation()));
+ return referenceOfOp.getReference();
+ }
+ if (auto specConstCompositeReplicateOp =
+ getSpecConstantCompositeReplicate(id)) {
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
- unknownLoc, constCompositeOp.getType(),
- SymbolRefAttr::get(constCompositeOp.getOperation()));
+ unknownLoc, specConstCompositeReplicateOp.getType(),
+ SymbolRefAttr::get(specConstCompositeReplicateOp.getOperation()));
return referenceOfOp.getReference();
}
if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
@@ -175,8 +188,12 @@ LogicalResult spirv::Deserializer::processInstruction(
return processConstant(operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantComposite:
return processConstantComposite(operands);
+ case spirv::Opcode::OpConstantCompositeReplicateEXT:
+ return processConstantCompositeReplicateEXT(operands);
case spirv::Opcode::OpSpecConstantComposite:
return processSpecConstantComposite(operands);
+ case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
+ return processSpecConstantCompositeReplicateEXT(operands);
case spirv::Opcode::OpSpecConstantOp:
return processSpecConstantOperation(operands);
case spirv::Opcode::OpConstantTrue:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b1abd8b3dffe9..2163ccff93c83 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -678,6 +678,14 @@ spirv::Deserializer::getConstant(uint32_t id) {
return constIt->getSecond();
}
+std::optional<std::pair<uint32_t, Type>>
+spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
+ auto constIt = constantCompositeReplicateMap.find(id);
+ if (constIt == constantCompositeReplicateMap.end())
+ return std::nullopt;
+ return constIt->getSecond();
+}
+
std::optional<spirv::SpecConstOperationMaterializationInfo>
spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
auto constIt = specConstOperationMap.find(id);
@@ -1554,15 +1562,58 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
+ ArrayRef<uint32_t> operands) {
+
+ if (operands.size() != 3) {
+ return emitError(
+ unknownLoc,
+ "OpConstantCompositeReplicateEXT must have type <id> and result <id> "
+ "and only one parameter which is <id> of splat constant");
+ }
+
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ auto compositeType = dyn_cast<CompositeType>(resultType);
+ if (!compositeType) {
+ return emitError(unknownLoc,
+ "result type from <id> is not a composite type")
+ << operands[0];
+ }
+
+ auto resultID = operands[1];
+ auto constantID = operands[2];
+
+ auto constantInfo = getConstant(constantID);
+ auto replicatedConstantCompositeInfo =
+ getConstantCompositeReplicate(constantID);
+ if (!constantInfo && !replicatedConstantCompositeInfo) {
+ return emitError(unknownLoc,
+ "OpConstantCompositeReplicateEXT operand <id> ")
+ << constantID
+ << " must come from a normal constant or a "
+ "OpConstantCompositeReplicateEXT";
+ }
+
+ constantCompositeReplicateMap.try_emplace(resultID, constantID, resultType);
+
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
if (operands.size() < 2) {
- return emitError(unknownLoc,
- "OpConstantComposite must have type <id> and result <id>");
+ return emitError(
+ unknownLoc,
+ "OpSpecConstantComposite must have type <id> and result <id>");
}
if (operands.size() < 3) {
return emitError(unknownLoc,
- "OpConstantComposite must have at least 1 parameter");
+ "OpSpecConstantComposite must have at least 1 parameter");
}
Type resultType = getType(operands[0]);
@@ -1589,6 +1640,42 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
+ ArrayRef<uint32_t> operands) {
+
+ if (operands.size() != 3) {
+ return emitError(unknownLoc,
+ "OpSpecConstantCompositeReplicateEXT must have "
+ "type <id> and result <id> and only one parameter which "
+ "is <id> of splat constant");
+ }
+
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ auto compositeType = dyn_cast<CompositeType>(resultType);
+ if (!compositeType) {
+ return emitError(unknownLoc,
+ "result type from <id> is not a composite type")
+ << operands[0];
+ }
+
+ auto resultID = operands[1];
+
+ auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
+ auto constituentSpecConstantOp = getSpecConstant(operands[2]);
+ auto op = opBuilder.create<spirv::EXTSpecConstantCompositeReplicateOp>(
+ unknownLoc, TypeAttr::get(resultType), symName,
+ SymbolRefAttr::get(constituentSpecConstantOp));
+
+ specConstCompositeReplicateMap[resultID] = op;
+
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
if (operands.size() < 3)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 1bc9e4a3c75d8..1fdecc3e6fe0d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -190,6 +190,12 @@ class Deseria...
[truncated]
|
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for submitting this PR. Few general and specific comments from me:
-
We probably should also add tests to: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/SPIRV/IR/availability.mlir
-
I wonder whether it would make sense to add some tests to https://github.com/llvm/llvm-project/tree/main/mlir/test/Dialect/SPIRV/IR including some negative tests?
-
Regarding the use of braces; I feel like in some places they could be dropped. See: https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements But I leave the decision up to you.
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Many thanks for your comments @IgWod-IMG ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left extra two small comments, but I will wait with the final review until you address kuhar's comments, as some of the code may change.
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No more comments from me.
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@kuhar Many thanks again for the great code review. Please let me know if there's anything else needed here. Otherwise, if all looks good, would it be possible to merge please? Thank you. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, no more comments from my side
@kuhar I appreciate if you can please update me if there is anything else left to be done from my side on this patch. Thank you. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for all the changes, LGTM. Have you checked if it passes validation with spirv-val? If so, we can merge.
@kuhar Many thanks. Yes, I have checked the results with spirv validator. |
@mahabadm Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
This patch introduces two new ops to the SPIR-V dialect:
spirv.EXT.ConstantCompositeReplicate
spirv.EXT.SpecConstantCompositeReplicate
These ops represent composite constants and specialization constants, respectively, constructed by replicating a single splat constant across all elements. They correspond to
SPV_EXT_replicated_composites
extension instructions:OpConstantCompositeReplicatedEXT
OpSpecConstantCompositeReplicatedEXT
No transformation to these new ops has been introduced in this patch.
This approach is chosen as per the discussions on RFC https://discourse.llvm.org/t/rfc-basic-support-for-spv-ext-replicated-composites-in-mlir-spir-v-compile-time-constant-lowering-only/86987