Skip to content

Conversation

mahabadm
Copy link
Contributor

@mahabadm mahabadm commented Jul 4, 2025

This patch introduces two new ops to the SPIR-V dialect:

  • spirv.EXT.ConstantCompositeReplicate
  • spirv.EXT.SpecConstantCompositeReplicate

These ops represent composite constants and specialization constants, respectively, constructed by replicating a single splat constant across all elements. They correspond to SPV_EXT_replicated_composites extension instructions:

  • OpConstantCompositeReplicatedEXT
  • OpSpecConstantCompositeReplicatedEXT

No transformation to these new ops has been introduced in this patch.

This approach is chosen as per the discussions on RFC https://discourse.llvm.org/t/rfc-basic-support-for-spv-ext-replicated-composites-in-mlir-spir-v-compile-time-constant-lowering-only/86987

This patch introduces two new ops to the SPIR-V dialect:
- `spirv.EXT.ConstantCompositeReplicate`
- `spirv.EXT.SpecConstantCompositeReplicate`

These ops represent composite constants and specialization constants, respectively, constructed by replicating a single splat constant across all elements. They correspond to `SPV_EXT_replicated_composites` extension instructions:
- `OpConstantCompositeReplicatedEXT`
- `OpSpecConstantCompositeReplicatedEXT`

No transformation to these new ops has been introduced in this patch.

This approach is chosen as per the discussions on RFC https://discourse.llvm.org/t/rfc-basic-support-for-spv-ext-replicated-composites-in-mlir-spir-v-compile-time-constant-lowering-only/86987

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Copy link

github-actions bot commented Jul 4, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2025

@llvm/pr-subscribers-mlir-spirv

Author: Mohammadreza Ameri Mahabadian (mahabadm)

Changes

This patch introduces two new ops to the SPIR-V dialect:

  • spirv.EXT.ConstantCompositeReplicate
  • spirv.EXT.SpecConstantCompositeReplicate

These ops represent composite constants and specialization constants, respectively, constructed by replicating a single splat constant across all elements. They correspond to SPV_EXT_replicated_composites extension instructions:

  • OpConstantCompositeReplicatedEXT
  • OpSpecConstantCompositeReplicatedEXT

No transformation to these new ops has been introduced in this patch.

This approach is chosen as per the discussions on RFC https://discourse.llvm.org/t/rfc-basic-support-for-spv-ext-replicated-composites-in-mlir-spir-v-compile-time-constant-lowering-only/86987


Patch is 36.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147067.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+13-2)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td (+86)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+119)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp (+20-3)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+90-3)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+39)
  • (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+36)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+46)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.h (+12)
  • (modified) mlir/test/Target/SPIRV/constant.mlir (+82-1)
  • (modified) mlir/test/Target/SPIRV/spec-constant.mlir (+27)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d874817e6888d..6c24dbc613c82 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -359,6 +359,7 @@ def SPV_EXT_shader_atomic_float_min_max  : I32EnumAttrCase<"SPV_EXT_shader_atomi
 def SPV_EXT_shader_image_int64           : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
 def SPV_EXT_shader_atomic_float16_add    : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
 def SPV_EXT_mesh_shader                  : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
+def SPV_EXT_replicated_composites        : I32EnumAttrCase<"SPV_EXT_replicated_composites", 1013>;
 
 def SPV_AMD_gpu_shader_half_float_fetch          : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
 def SPV_AMD_shader_ballot                        : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -446,7 +447,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
-      SPV_EXT_mesh_shader,
+      SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
       SPV_ARM_tensors,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
@@ -849,6 +850,12 @@ def SPIRV_C_CooperativeMatrixKHR                        : I32EnumAttrCase<"Coope
     MinVersion<SPIRV_V_1_6>
   ];
 }
+def SPIRV_C_ReplicatedCompositesEXT                     : I32EnumAttrCase<"ReplicatedCompositesEXT", 6024> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_replicated_composites]>,
+    MinVersion<SPIRV_V_1_0>
+  ];
+}
 def SPIRV_C_BitInstructions                             : I32EnumAttrCase<"BitInstructions", 6025> {
   list<Availability> availability = [
     Extension<[SPV_KHR_bit_instructions]>
@@ -1500,7 +1507,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL,
       SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll,
       SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR,
-      SPIRV_C_CooperativeMatrixKHR,
+      SPIRV_C_CooperativeMatrixKHR, SPIRV_C_ReplicatedCompositesEXT,
       SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
       SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL,
       SPIRV_C_AtomicFloat16AddEXT, SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL,
@@ -4564,6 +4571,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR       : I32EnumAttrCase<"OpCooperativeMa
 def SPIRV_OC_OpCooperativeMatrixStoreKHR      : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
 def SPIRV_OC_OpCooperativeMatrixMulAddKHR     : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
 def SPIRV_OC_OpCooperativeMatrixLengthKHR     : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
+def SPIRV_OC_OpConstantCompositeReplicateEXT : I32EnumAttrCase<"OpConstantCompositeReplicateEXT", 4461>;
+def SPIRV_OC_OpSpecConstantCompositeReplicateEXT : I32EnumAttrCase<"OpSpecConstantCompositeReplicateEXT", 4462>;
 def SPIRV_OC_OpEmitMeshTasksEXT               : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
 def SPIRV_OC_OpSetMeshOutputsEXT              : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
 def SPIRV_OC_OpSubgroupBlockReadINTEL         : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
@@ -4672,6 +4681,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
       SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR,
       SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+      SPIRV_OC_OpConstantCompositeReplicateEXT,
+      SPIRV_OC_OpSpecConstantCompositeReplicateEXT,
       SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
       SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
       SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index c5a85f881b35e..0a5b01fe9e8d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -135,6 +135,52 @@ def SPIRV_ConstantOp : SPIRV_Op<"Constant",
   let autogenSerialization = 0;
 }
 
+
+// -----
+
+def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantCompositeReplicate", [Pure]> {
+  let summary = [{
+    Declare a new replicated composite constant op.
+  }];
+
+  let description = [{
+    This op declares a `spiv.EXT.ConstantCompositeReplicate` which represents a
+    splat composite constant i.e. all element of composite constant have the
+    same value. This op will be serialized to SPIR-V `OpConstantCompositeReplicateEXT`.
+    The splat value must come from a non-specialization constant instruction."
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.Constant 1 : i32
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
+
+    %2 = spirv.Constant dense<[1, 2]> : vector<2xi32>
+    %3 = spirv.EXT.ConstantCompositeReplicate %2 : !spirv.array<2 x vector<2xi32>>
+
+    %5 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+    %6 = spirv.EXT.ConstantCompositeReplicate %5 : !spirv.array<2 x vector<2xi32>>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_replicated_composites]>,
+    Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+  ];
+
+  let arguments = (ins
+    SPIRV_Type:$constant
+  );
+
+  let results = (outs
+    SPIRV_Composite:$replicated_constant
+  );
+
+  let autogenSerialization = 0;
+}
+
 // -----
 
 def SPIRV_EntryPointOp : SPIRV_Op<"EntryPoint", [InModuleScope]> {
@@ -689,6 +735,46 @@ def SPIRV_SpecConstantCompositeOp : SPIRV_Op<"SpecConstantComposite", [
 
 // -----
 
+def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantCompositeReplicate", [InModuleScope, Symbol]> {
+  let summary = "Declare a new replicated composite specialization constant op.";
+
+  let description = [{
+    This op declares a `spirv.EXT.SpecConstantCompositeReplicate` which represents
+    a splat specialization composite constant i.e. all element of specialization
+    composite constant have the same value. This op will be serialized to SPIR-V
+    `OpSpecConstantCompositeReplicateEXT`. The splat value must come from a
+    symbol reference of specialization constant instruction.
+
+    #### Example:
+
+    ```mlir
+    spirv.SpecConstant @sc_i32_1 = 1 : i32
+    spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+    spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_replicated_composites]>,
+    Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+  ];
+
+  let arguments = (ins
+    TypeAttr:$type,
+    StrAttr:$sym_name,
+    SymbolRefAttr:$constituent
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+
+}
+
+// -----
+
 def SPIRV_SpecConstantOperationOp : SPIRV_Op<"SpecConstantOperation", [
        Pure, InFunctionScope,
        SingleBlockImplicitTerminator<"YieldOp">]> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index eb2974d62fdd1..c42b2d45d53a9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -765,6 +765,67 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
   setNameFn(getResult(), specialName.str());
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.EXTConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+ParseResult
+spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  OpAsmParser::UnresolvedOperand constOperand;
+  Type compositeType;
+  if (parser.parseOperand(constOperand) ||
+      parser.parseColonType(compositeType)) {
+    return failure();
+  }
+
+  if (llvm::isa<TensorType>(compositeType)) {
+    if (parser.parseColonType(compositeType))
+      return failure();
+  }
+
+  auto constType = cast<spirv::CompositeType>(compositeType).getElementType(0);
+  while (auto type = llvm::dyn_cast<spirv::ArrayType>(constType)) {
+    constType = type.getElementType();
+  }
+
+  if (parser.resolveOperand(constOperand, constType, result.operands))
+    return failure();
+
+  return parser.addTypeToList(compositeType, result.types);
+}
+
+void spirv::EXTConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getConstant() << " : " << getType();
+}
+
+LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
+  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+  if (!compositeType)
+    return emitError("result type must be a composite type, but provided ")
+           << getType();
+
+  auto constantDefiningOp = getConstant().getDefiningOp();
+  if (!constantDefiningOp)
+    return this->emitOpError("op defining the splat constant not found");
+
+  auto constantOp = dyn_cast_or_null<spirv::ConstantOp>(constantDefiningOp);
+  auto constantCompositeReplicateOp =
+      dyn_cast_or_null<spirv::EXTConstantCompositeReplicateOp>(
+          constantDefiningOp);
+
+  if (!constantOp && !constantCompositeReplicateOp)
+    return this->emitOpError(
+        "op defining the splat constant is not a spirv.Constant or a "
+        "spirv.EXT.ConstantCompositeReplicate");
+
+  if (constantOp)
+    return verifyConstantType(constantOp, constantOp.getValueAttr(),
+                              constantOp.getType());
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.ControlBarrierOp
 //===----------------------------------------------------------------------===//
@@ -1866,6 +1927,64 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.EXTSpecConstantCompositeReplicateOp
+//===----------------------------------------------------------------------===//
+
+ParseResult
+spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+                                                  OperationState &result) {
+
+  StringAttr compositeName;
+  const char *attrName = "spec_const";
+  FlatSymbolRefAttr specConstRef;
+  NamedAttrList attrs;
+  Type type;
+
+  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
+                             result.attributes) ||
+      parser.parseLParen() ||
+      parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
+      parser.parseRParen() || parser.parseColonType(type))
+    return failure();
+
+  StringAttr compositeSpecConstituentName =
+      spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
+          result.name);
+  result.addAttribute(compositeSpecConstituentName, specConstRef);
+
+  StringAttr typeAttrName =
+      spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
+
+  return success();
+}
+
+void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getSymName());
+  printer << " (" << this->getConstituent() << ") : " << getType();
+}
+
+LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
+  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+  if (!compositeType)
+    return emitError("result type must be a composite type, but provided ")
+           << getType();
+
+  auto constituentSpecConstOp =
+      dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
+          (*this)->getParentOp(), this->getConstituent()));
+
+  auto constituentType = constituentSpecConstOp.getDefaultValue().getType();
+  auto compositeElemType = compositeType.getElementType(0);
+  if (constituentType != compositeElemType)
+    return emitError("constituent has incorrect type: expected ")
+           << compositeElemType << ", but provided " << constituentType;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.SpecConstantOperation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 55d6a380d0bff..5f52308b4be35 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -45,6 +45,12 @@ Value spirv::Deserializer::getValue(uint32_t id) {
     return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
                                                constInfo->first);
   }
+  if (auto constCompositeReplicateInfo = getConstantCompositeReplicate(id)) {
+    auto constantId = constCompositeReplicateInfo->first;
+    auto element = getValue(constantId);
+    return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
+        unknownLoc, constCompositeReplicateInfo->second, element);
+  }
   if (auto varOp = getGlobalVariable(id)) {
     auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
         unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
@@ -56,10 +62,17 @@ Value spirv::Deserializer::getValue(uint32_t id) {
         SymbolRefAttr::get(constOp.getOperation()));
     return referenceOfOp.getReference();
   }
-  if (auto constCompositeOp = getSpecConstantComposite(id)) {
+  if (auto specConstCompositeOp = getSpecConstantComposite(id)) {
+    auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+        unknownLoc, specConstCompositeOp.getType(),
+        SymbolRefAttr::get(specConstCompositeOp.getOperation()));
+    return referenceOfOp.getReference();
+  }
+  if (auto specConstCompositeReplicateOp =
+          getSpecConstantCompositeReplicate(id)) {
     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
-        unknownLoc, constCompositeOp.getType(),
-        SymbolRefAttr::get(constCompositeOp.getOperation()));
+        unknownLoc, specConstCompositeReplicateOp.getType(),
+        SymbolRefAttr::get(specConstCompositeReplicateOp.getOperation()));
     return referenceOfOp.getReference();
   }
   if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
@@ -175,8 +188,12 @@ LogicalResult spirv::Deserializer::processInstruction(
     return processConstant(operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantComposite:
     return processConstantComposite(operands);
+  case spirv::Opcode::OpConstantCompositeReplicateEXT:
+    return processConstantCompositeReplicateEXT(operands);
   case spirv::Opcode::OpSpecConstantComposite:
     return processSpecConstantComposite(operands);
+  case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
+    return processSpecConstantCompositeReplicateEXT(operands);
   case spirv::Opcode::OpSpecConstantOp:
     return processSpecConstantOperation(operands);
   case spirv::Opcode::OpConstantTrue:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b1abd8b3dffe9..2163ccff93c83 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -678,6 +678,14 @@ spirv::Deserializer::getConstant(uint32_t id) {
   return constIt->getSecond();
 }
 
+std::optional<std::pair<uint32_t, Type>>
+spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
+  auto constIt = constantCompositeReplicateMap.find(id);
+  if (constIt == constantCompositeReplicateMap.end())
+    return std::nullopt;
+  return constIt->getSecond();
+}
+
 std::optional<spirv::SpecConstOperationMaterializationInfo>
 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
   auto constIt = specConstOperationMap.find(id);
@@ -1554,15 +1562,58 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
+    ArrayRef<uint32_t> operands) {
+
+  if (operands.size() != 3) {
+    return emitError(
+        unknownLoc,
+        "OpConstantCompositeReplicateEXT must have type <id> and result <id> "
+        "and only one parameter which is <id> of splat constant");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  auto compositeType = dyn_cast<CompositeType>(resultType);
+  if (!compositeType) {
+    return emitError(unknownLoc,
+                     "result type from <id> is not a composite type")
+           << operands[0];
+  }
+
+  auto resultID = operands[1];
+  auto constantID = operands[2];
+
+  auto constantInfo = getConstant(constantID);
+  auto replicatedConstantCompositeInfo =
+      getConstantCompositeReplicate(constantID);
+  if (!constantInfo && !replicatedConstantCompositeInfo) {
+    return emitError(unknownLoc,
+                     "OpConstantCompositeReplicateEXT operand <id> ")
+           << constantID
+           << " must come from a normal constant or a "
+              "OpConstantCompositeReplicateEXT";
+  }
+
+  constantCompositeReplicateMap.try_emplace(resultID, constantID, resultType);
+
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
   if (operands.size() < 2) {
-    return emitError(unknownLoc,
-                     "OpConstantComposite must have type <id> and result <id>");
+    return emitError(
+        unknownLoc,
+        "OpSpecConstantComposite must have type <id> and result <id>");
   }
   if (operands.size() < 3) {
     return emitError(unknownLoc,
-                     "OpConstantComposite must have at least 1 parameter");
+                     "OpSpecConstantComposite must have at least 1 parameter");
   }
 
   Type resultType = getType(operands[0]);
@@ -1589,6 +1640,42 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
+    ArrayRef<uint32_t> operands) {
+
+  if (operands.size() != 3) {
+    return emitError(unknownLoc,
+                     "OpSpecConstantCompositeReplicateEXT must have "
+                     "type <id> and result <id> and only one parameter which "
+                     "is <id> of splat constant");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  auto compositeType = dyn_cast<CompositeType>(resultType);
+  if (!compositeType) {
+    return emitError(unknownLoc,
+                     "result type from <id> is not a composite type")
+           << operands[0];
+  }
+
+  auto resultID = operands[1];
+
+  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
+  auto constituentSpecConstantOp = getSpecConstant(operands[2]);
+  auto op = opBuilder.create<spirv::EXTSpecConstantCompositeReplicateOp>(
+      unknownLoc, TypeAttr::get(resultType), symName,
+      SymbolRefAttr::get(constituentSpecConstantOp));
+
+  specConstCompositeReplicateMap[resultID] = op;
+
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
   if (operands.size() < 3)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 1bc9e4a3c75d8..1fdecc3e6fe0d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -190,6 +190,12 @@ class Deseria...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2025

@llvm/pr-subscribers-mlir

Author: Mohammadreza Ameri Mahabadian (mahabadm)

Changes

This patch introduces two new ops to the SPIR-V dialect:

  • spirv.EXT.ConstantCompositeReplicate
  • spirv.EXT.SpecConstantCompositeReplicate

These ops represent composite constants and specialization constants, respectively, constructed by replicating a single splat constant across all elements. They correspond to SPV_EXT_replicated_composites extension instructions:

  • OpConstantCompositeReplicatedEXT
  • OpSpecConstantCompositeReplicatedEXT

No transformation to these new ops has been introduced in this patch.

This approach is chosen as per the discussions on RFC https://discourse.llvm.org/t/rfc-basic-support-for-spv-ext-replicated-composites-in-mlir-spir-v-compile-time-constant-lowering-only/86987


Patch is 36.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147067.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+13-2)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td (+86)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+119)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp (+20-3)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+90-3)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+39)
  • (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+36)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+46)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.h (+12)
  • (modified) mlir/test/Target/SPIRV/constant.mlir (+82-1)
  • (modified) mlir/test/Target/SPIRV/spec-constant.mlir (+27)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d874817e6888d..6c24dbc613c82 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -359,6 +359,7 @@ def SPV_EXT_shader_atomic_float_min_max  : I32EnumAttrCase<"SPV_EXT_shader_atomi
 def SPV_EXT_shader_image_int64           : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
 def SPV_EXT_shader_atomic_float16_add    : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
 def SPV_EXT_mesh_shader                  : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
+def SPV_EXT_replicated_composites        : I32EnumAttrCase<"SPV_EXT_replicated_composites", 1013>;
 
 def SPV_AMD_gpu_shader_half_float_fetch          : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
 def SPV_AMD_shader_ballot                        : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -446,7 +447,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
-      SPV_EXT_mesh_shader,
+      SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
       SPV_ARM_tensors,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
@@ -849,6 +850,12 @@ def SPIRV_C_CooperativeMatrixKHR                        : I32EnumAttrCase<"Coope
     MinVersion<SPIRV_V_1_6>
   ];
 }
+def SPIRV_C_ReplicatedCompositesEXT                     : I32EnumAttrCase<"ReplicatedCompositesEXT", 6024> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_replicated_composites]>,
+    MinVersion<SPIRV_V_1_0>
+  ];
+}
 def SPIRV_C_BitInstructions                             : I32EnumAttrCase<"BitInstructions", 6025> {
   list<Availability> availability = [
     Extension<[SPV_KHR_bit_instructions]>
@@ -1500,7 +1507,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL,
       SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll,
       SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR,
-      SPIRV_C_CooperativeMatrixKHR,
+      SPIRV_C_CooperativeMatrixKHR, SPIRV_C_ReplicatedCompositesEXT,
       SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
       SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL,
       SPIRV_C_AtomicFloat16AddEXT, SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL,
@@ -4564,6 +4571,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR       : I32EnumAttrCase<"OpCooperativeMa
 def SPIRV_OC_OpCooperativeMatrixStoreKHR      : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
 def SPIRV_OC_OpCooperativeMatrixMulAddKHR     : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
 def SPIRV_OC_OpCooperativeMatrixLengthKHR     : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
+def SPIRV_OC_OpConstantCompositeReplicateEXT : I32EnumAttrCase<"OpConstantCompositeReplicateEXT", 4461>;
+def SPIRV_OC_OpSpecConstantCompositeReplicateEXT : I32EnumAttrCase<"OpSpecConstantCompositeReplicateEXT", 4462>;
 def SPIRV_OC_OpEmitMeshTasksEXT               : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
 def SPIRV_OC_OpSetMeshOutputsEXT              : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
 def SPIRV_OC_OpSubgroupBlockReadINTEL         : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
@@ -4672,6 +4681,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
       SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR,
       SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+      SPIRV_OC_OpConstantCompositeReplicateEXT,
+      SPIRV_OC_OpSpecConstantCompositeReplicateEXT,
       SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
       SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
       SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index c5a85f881b35e..0a5b01fe9e8d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -135,6 +135,52 @@ def SPIRV_ConstantOp : SPIRV_Op<"Constant",
   let autogenSerialization = 0;
 }
 
+
+// -----
+
+def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantCompositeReplicate", [Pure]> {
+  let summary = [{
+    Declare a new replicated composite constant op.
+  }];
+
+  let description = [{
+    This op declares a `spiv.EXT.ConstantCompositeReplicate` which represents a
+    splat composite constant i.e. all element of composite constant have the
+    same value. This op will be serialized to SPIR-V `OpConstantCompositeReplicateEXT`.
+    The splat value must come from a non-specialization constant instruction."
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.Constant 1 : i32
+    %1 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xf32>
+
+    %2 = spirv.Constant dense<[1, 2]> : vector<2xi32>
+    %3 = spirv.EXT.ConstantCompositeReplicate %2 : !spirv.array<2 x vector<2xi32>>
+
+    %5 = spirv.EXT.ConstantCompositeReplicate %0 : vector<2xi32>
+    %6 = spirv.EXT.ConstantCompositeReplicate %5 : !spirv.array<2 x vector<2xi32>>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_replicated_composites]>,
+    Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+  ];
+
+  let arguments = (ins
+    SPIRV_Type:$constant
+  );
+
+  let results = (outs
+    SPIRV_Composite:$replicated_constant
+  );
+
+  let autogenSerialization = 0;
+}
+
 // -----
 
 def SPIRV_EntryPointOp : SPIRV_Op<"EntryPoint", [InModuleScope]> {
@@ -689,6 +735,46 @@ def SPIRV_SpecConstantCompositeOp : SPIRV_Op<"SpecConstantComposite", [
 
 // -----
 
+def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantCompositeReplicate", [InModuleScope, Symbol]> {
+  let summary = "Declare a new replicated composite specialization constant op.";
+
+  let description = [{
+    This op declares a `spirv.EXT.SpecConstantCompositeReplicate` which represents
+    a splat specialization composite constant i.e. all element of specialization
+    composite constant have the same value. This op will be serialized to SPIR-V
+    `OpSpecConstantCompositeReplicateEXT`. The splat value must come from a
+    symbol reference of specialization constant instruction.
+
+    #### Example:
+
+    ```mlir
+    spirv.SpecConstant @sc_i32_1 = 1 : i32
+    spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+    spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_replicated_composites]>,
+    Capability<[SPIRV_C_ReplicatedCompositesEXT]>
+  ];
+
+  let arguments = (ins
+    TypeAttr:$type,
+    StrAttr:$sym_name,
+    SymbolRefAttr:$constituent
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+
+}
+
+// -----
+
 def SPIRV_SpecConstantOperationOp : SPIRV_Op<"SpecConstantOperation", [
        Pure, InFunctionScope,
        SingleBlockImplicitTerminator<"YieldOp">]> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index eb2974d62fdd1..c42b2d45d53a9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -765,6 +765,67 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
   setNameFn(getResult(), specialName.str());
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.EXTConstantCompositeReplicate
+//===----------------------------------------------------------------------===//
+
+ParseResult
+spirv::EXTConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  OpAsmParser::UnresolvedOperand constOperand;
+  Type compositeType;
+  if (parser.parseOperand(constOperand) ||
+      parser.parseColonType(compositeType)) {
+    return failure();
+  }
+
+  if (llvm::isa<TensorType>(compositeType)) {
+    if (parser.parseColonType(compositeType))
+      return failure();
+  }
+
+  auto constType = cast<spirv::CompositeType>(compositeType).getElementType(0);
+  while (auto type = llvm::dyn_cast<spirv::ArrayType>(constType)) {
+    constType = type.getElementType();
+  }
+
+  if (parser.resolveOperand(constOperand, constType, result.operands))
+    return failure();
+
+  return parser.addTypeToList(compositeType, result.types);
+}
+
+void spirv::EXTConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getConstant() << " : " << getType();
+}
+
+LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
+  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+  if (!compositeType)
+    return emitError("result type must be a composite type, but provided ")
+           << getType();
+
+  auto constantDefiningOp = getConstant().getDefiningOp();
+  if (!constantDefiningOp)
+    return this->emitOpError("op defining the splat constant not found");
+
+  auto constantOp = dyn_cast_or_null<spirv::ConstantOp>(constantDefiningOp);
+  auto constantCompositeReplicateOp =
+      dyn_cast_or_null<spirv::EXTConstantCompositeReplicateOp>(
+          constantDefiningOp);
+
+  if (!constantOp && !constantCompositeReplicateOp)
+    return this->emitOpError(
+        "op defining the splat constant is not a spirv.Constant or a "
+        "spirv.EXT.ConstantCompositeReplicate");
+
+  if (constantOp)
+    return verifyConstantType(constantOp, constantOp.getValueAttr(),
+                              constantOp.getType());
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.ControlBarrierOp
 //===----------------------------------------------------------------------===//
@@ -1866,6 +1927,64 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.EXTSpecConstantCompositeReplicateOp
+//===----------------------------------------------------------------------===//
+
+ParseResult
+spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
+                                                  OperationState &result) {
+
+  StringAttr compositeName;
+  const char *attrName = "spec_const";
+  FlatSymbolRefAttr specConstRef;
+  NamedAttrList attrs;
+  Type type;
+
+  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
+                             result.attributes) ||
+      parser.parseLParen() ||
+      parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
+      parser.parseRParen() || parser.parseColonType(type))
+    return failure();
+
+  StringAttr compositeSpecConstituentName =
+      spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
+          result.name);
+  result.addAttribute(compositeSpecConstituentName, specConstRef);
+
+  StringAttr typeAttrName =
+      spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
+
+  return success();
+}
+
+void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getSymName());
+  printer << " (" << this->getConstituent() << ") : " << getType();
+}
+
+LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
+  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
+  if (!compositeType)
+    return emitError("result type must be a composite type, but provided ")
+           << getType();
+
+  auto constituentSpecConstOp =
+      dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
+          (*this)->getParentOp(), this->getConstituent()));
+
+  auto constituentType = constituentSpecConstOp.getDefaultValue().getType();
+  auto compositeElemType = compositeType.getElementType(0);
+  if (constituentType != compositeElemType)
+    return emitError("constituent has incorrect type: expected ")
+           << compositeElemType << ", but provided " << constituentType;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.SpecConstantOperation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 55d6a380d0bff..5f52308b4be35 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -45,6 +45,12 @@ Value spirv::Deserializer::getValue(uint32_t id) {
     return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
                                                constInfo->first);
   }
+  if (auto constCompositeReplicateInfo = getConstantCompositeReplicate(id)) {
+    auto constantId = constCompositeReplicateInfo->first;
+    auto element = getValue(constantId);
+    return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
+        unknownLoc, constCompositeReplicateInfo->second, element);
+  }
   if (auto varOp = getGlobalVariable(id)) {
     auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
         unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
@@ -56,10 +62,17 @@ Value spirv::Deserializer::getValue(uint32_t id) {
         SymbolRefAttr::get(constOp.getOperation()));
     return referenceOfOp.getReference();
   }
-  if (auto constCompositeOp = getSpecConstantComposite(id)) {
+  if (auto specConstCompositeOp = getSpecConstantComposite(id)) {
+    auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+        unknownLoc, specConstCompositeOp.getType(),
+        SymbolRefAttr::get(specConstCompositeOp.getOperation()));
+    return referenceOfOp.getReference();
+  }
+  if (auto specConstCompositeReplicateOp =
+          getSpecConstantCompositeReplicate(id)) {
     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
-        unknownLoc, constCompositeOp.getType(),
-        SymbolRefAttr::get(constCompositeOp.getOperation()));
+        unknownLoc, specConstCompositeReplicateOp.getType(),
+        SymbolRefAttr::get(specConstCompositeReplicateOp.getOperation()));
     return referenceOfOp.getReference();
   }
   if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
@@ -175,8 +188,12 @@ LogicalResult spirv::Deserializer::processInstruction(
     return processConstant(operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantComposite:
     return processConstantComposite(operands);
+  case spirv::Opcode::OpConstantCompositeReplicateEXT:
+    return processConstantCompositeReplicateEXT(operands);
   case spirv::Opcode::OpSpecConstantComposite:
     return processSpecConstantComposite(operands);
+  case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
+    return processSpecConstantCompositeReplicateEXT(operands);
   case spirv::Opcode::OpSpecConstantOp:
     return processSpecConstantOperation(operands);
   case spirv::Opcode::OpConstantTrue:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b1abd8b3dffe9..2163ccff93c83 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -678,6 +678,14 @@ spirv::Deserializer::getConstant(uint32_t id) {
   return constIt->getSecond();
 }
 
+std::optional<std::pair<uint32_t, Type>>
+spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
+  auto constIt = constantCompositeReplicateMap.find(id);
+  if (constIt == constantCompositeReplicateMap.end())
+    return std::nullopt;
+  return constIt->getSecond();
+}
+
 std::optional<spirv::SpecConstOperationMaterializationInfo>
 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
   auto constIt = specConstOperationMap.find(id);
@@ -1554,15 +1562,58 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
+    ArrayRef<uint32_t> operands) {
+
+  if (operands.size() != 3) {
+    return emitError(
+        unknownLoc,
+        "OpConstantCompositeReplicateEXT must have type <id> and result <id> "
+        "and only one parameter which is <id> of splat constant");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  auto compositeType = dyn_cast<CompositeType>(resultType);
+  if (!compositeType) {
+    return emitError(unknownLoc,
+                     "result type from <id> is not a composite type")
+           << operands[0];
+  }
+
+  auto resultID = operands[1];
+  auto constantID = operands[2];
+
+  auto constantInfo = getConstant(constantID);
+  auto replicatedConstantCompositeInfo =
+      getConstantCompositeReplicate(constantID);
+  if (!constantInfo && !replicatedConstantCompositeInfo) {
+    return emitError(unknownLoc,
+                     "OpConstantCompositeReplicateEXT operand <id> ")
+           << constantID
+           << " must come from a normal constant or a "
+              "OpConstantCompositeReplicateEXT";
+  }
+
+  constantCompositeReplicateMap.try_emplace(resultID, constantID, resultType);
+
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
   if (operands.size() < 2) {
-    return emitError(unknownLoc,
-                     "OpConstantComposite must have type <id> and result <id>");
+    return emitError(
+        unknownLoc,
+        "OpSpecConstantComposite must have type <id> and result <id>");
   }
   if (operands.size() < 3) {
     return emitError(unknownLoc,
-                     "OpConstantComposite must have at least 1 parameter");
+                     "OpSpecConstantComposite must have at least 1 parameter");
   }
 
   Type resultType = getType(operands[0]);
@@ -1589,6 +1640,42 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
+    ArrayRef<uint32_t> operands) {
+
+  if (operands.size() != 3) {
+    return emitError(unknownLoc,
+                     "OpSpecConstantCompositeReplicateEXT must have "
+                     "type <id> and result <id> and only one parameter which "
+                     "is <id> of splat constant");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  auto compositeType = dyn_cast<CompositeType>(resultType);
+  if (!compositeType) {
+    return emitError(unknownLoc,
+                     "result type from <id> is not a composite type")
+           << operands[0];
+  }
+
+  auto resultID = operands[1];
+
+  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
+  auto constituentSpecConstantOp = getSpecConstant(operands[2]);
+  auto op = opBuilder.create<spirv::EXTSpecConstantCompositeReplicateOp>(
+      unknownLoc, TypeAttr::get(resultType), symName,
+      SymbolRefAttr::get(constituentSpecConstantOp));
+
+  specConstCompositeReplicateMap[resultID] = op;
+
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
   if (operands.size() < 3)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 1bc9e4a3c75d8..1fdecc3e6fe0d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -190,6 +190,12 @@ class Deseria...
[truncated]

@kuhar kuhar requested a review from IgWod-IMG July 4, 2025 14:48
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Copy link
Contributor

@IgWod-IMG IgWod-IMG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for submitting this PR. Few general and specific comments from me:

  1. We probably should also add tests to: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/SPIRV/IR/availability.mlir

  2. I wonder whether it would make sense to add some tests to https://github.com/llvm/llvm-project/tree/main/mlir/test/Dialect/SPIRV/IR including some negative tests?

  3. Regarding the use of braces; I feel like in some places they could be dropped. See: https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements But I leave the decision up to you.

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@mahabadm
Copy link
Contributor Author

mahabadm commented Jul 7, 2025

Thank you for submitting this PR. Few general and specific comments from me:

1. We probably should also add tests to: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/SPIRV/IR/availability.mlir

2. I wonder whether it would make sense to add some tests to https://github.com/llvm/llvm-project/tree/main/mlir/test/Dialect/SPIRV/IR including some negative tests?

3. Regarding the use of braces; I feel like in some places they could be dropped. See: https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements But I leave the decision up to you.

Many thanks for your comments @IgWod-IMG !
I have now added several negative tests and also a test in availability.mlir. These were great suggestions resulting in improving the patch even further.

Copy link
Contributor

@IgWod-IMG IgWod-IMG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left extra two small comments, but I will wait with the final review until you address kuhar's comments, as some of the code may change.

@mahabadm mahabadm requested a review from kuhar July 10, 2025 07:23
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Copy link
Contributor

@IgWod-IMG IgWod-IMG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No more comments from me.

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@kuhar kuhar requested a review from Hardcode84 July 11, 2025 14:13
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@mahabadm mahabadm requested a review from kuhar July 11, 2025 16:25
@mahabadm
Copy link
Contributor Author

@kuhar Many thanks again for the great code review. Please let me know if there's anything else needed here. Otherwise, if all looks good, would it be possible to merge please? Thank you.

Copy link
Contributor

@Hardcode84 Hardcode84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, no more comments from my side

@mahabadm
Copy link
Contributor Author

@kuhar I appreciate if you can please update me if there is anything else left to be done from my side on this patch. Thank you.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all the changes, LGTM. Have you checked if it passes validation with spirv-val? If so, we can merge.

@mahabadm
Copy link
Contributor Author

@kuhar Many thanks. Yes, I have checked the results with spirv validator.

@kuhar kuhar merged commit 94b15a1 into llvm:main Jul 15, 2025
9 checks passed
Copy link

@mahabadm Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

@mahabadm mahabadm deleted the add_basic_support_for_SPV_EXT_replicated_composites branch July 15, 2025 14:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants