Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def SPV_EXT_shader_atomic_float_min_max : I32EnumAttrCase<"SPV_EXT_shader_atomi
def SPV_EXT_shader_image_int64 : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
def SPV_EXT_shader_atomic_float16_add : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
def SPV_EXT_mesh_shader : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
def SPV_EXT_replicated_composites : I32EnumAttrCase<"SPV_EXT_replicated_composites", 1013>;

def SPV_AMD_gpu_shader_half_float_fetch : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
def SPV_AMD_shader_ballot : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
Expand Down Expand Up @@ -446,7 +447,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
SPV_EXT_mesh_shader,
SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
SPV_ARM_tensors,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
Expand Down Expand Up @@ -849,6 +850,12 @@ def SPIRV_C_CooperativeMatrixKHR : I32EnumAttrCase<"Coope
MinVersion<SPIRV_V_1_6>
];
}
def SPIRV_C_ReplicatedCompositesEXT : I32EnumAttrCase<"ReplicatedCompositesEXT", 6024> {
list<Availability> availability = [
Extension<[SPV_EXT_replicated_composites]>,
MinVersion<SPIRV_V_1_0>
];
}
def SPIRV_C_BitInstructions : I32EnumAttrCase<"BitInstructions", 6025> {
list<Availability> availability = [
Extension<[SPV_KHR_bit_instructions]>
Expand Down Expand Up @@ -1500,7 +1507,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL,
SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll,
SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR,
SPIRV_C_CooperativeMatrixKHR,
SPIRV_C_CooperativeMatrixKHR, SPIRV_C_ReplicatedCompositesEXT,
SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL,
SPIRV_C_AtomicFloat16AddEXT, SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL,
Expand Down Expand Up @@ -4564,6 +4571,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMa
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
def SPIRV_OC_OpConstantCompositeReplicateEXT : I32EnumAttrCase<"OpConstantCompositeReplicateEXT", 4461>;
def SPIRV_OC_OpSpecConstantCompositeReplicateEXT : I32EnumAttrCase<"OpSpecConstantCompositeReplicateEXT", 4462>;
def SPIRV_OC_OpEmitMeshTasksEXT : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
def SPIRV_OC_OpSetMeshOutputsEXT : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
def SPIRV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
Expand Down Expand Up @@ -4672,6 +4681,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR,
SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
SPIRV_OC_OpConstantCompositeReplicateEXT,
SPIRV_OC_OpSpecConstantCompositeReplicateEXT,
SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
Expand Down
78 changes: 78 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,47 @@ def SPIRV_ConstantOp : SPIRV_Op<"Constant",
let autogenSerialization = 0;
}


// -----

def SPIRV_EXTConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"ConstantCompositeReplicate", [Pure]> {
let summary = [{
Declare a new replicated composite constant op.
}];

let description = [{
Represents a splat composite constant i.e., all elements of composite constant
have the same value.

#### Example:

```mlir
%0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xi32>
%1 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<2 x vector<2xi32>>
%2 = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
```
}];

let availability = [
MinVersion<SPIRV_V_1_0>,
MaxVersion<SPIRV_V_1_6>,
Extension<[SPV_EXT_replicated_composites]>,
Capability<[SPIRV_C_ReplicatedCompositesEXT]>
];

let arguments = (ins
AnyAttr:$value
);

let results = (outs
SPIRV_Composite:$replicated_constant
);

let autogenSerialization = 0;

let assemblyFormat = "` ` `[` $value `]` `:` type($replicated_constant) attr-dict";
}

// -----

def SPIRV_EntryPointOp : SPIRV_Op<"EntryPoint", [InModuleScope]> {
Expand Down Expand Up @@ -689,6 +730,43 @@ def SPIRV_SpecConstantCompositeOp : SPIRV_Op<"SpecConstantComposite", [

// -----

def SPIRV_EXTSpecConstantCompositeReplicateOp : SPIRV_ExtVendorOp<"SpecConstantCompositeReplicate", [InModuleScope, Symbol]> {
let summary = "Declare a new replicated composite specialization constant op.";

let description = [{
Represents a splat spec composite constant i.e., all elements of spec composite
constant have the same value. The splat value must come from a symbol reference
of spec constant instruction.

#### Example:

```mlir
spirv.SpecConstant @sc_i32_1 = 1 : i32
spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
```
}];

let availability = [
MinVersion<SPIRV_V_1_0>,
MaxVersion<SPIRV_V_1_6>,
Extension<[SPV_EXT_replicated_composites]>,
Capability<[SPIRV_C_ReplicatedCompositesEXT]>
];

let arguments = (ins
TypeAttr:$type,
StrAttr:$sym_name,
SymbolRefAttr:$constituent
);

let results = (outs);

let autogenSerialization = 0;
}

// -----

def SPIRV_SpecConstantOperationOp : SPIRV_Op<"SpecConstantOperation", [
Pure, InFunctionScope,
SingleBlockImplicitTerminator<"YieldOp">]> {
Expand Down
101 changes: 101 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,44 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
setNameFn(getResult(), specialName.str());
}

//===----------------------------------------------------------------------===//
// spirv.EXTConstantCompositeReplicate
//===----------------------------------------------------------------------===//

LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
Type valueType;
if (auto typedAttr = dyn_cast<TypedAttr>(getValue())) {
valueType = typedAttr.getType();
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
if (!typedElemAttr)
return emitError("value attribute is not typed");
valueType =
spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
} else {
return emitError("unknown value attribute type");
}

auto compositeType = dyn_cast<spirv::CompositeType>(getType());
if (!compositeType)
return emitError("result type is not a composite type");

Type compositeElementType = compositeType.getElementType(0);

SmallVector<Type, 3> possibleTypes = {compositeElementType};
while (auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
compositeElementType = type.getElementType(0);
possibleTypes.push_back(compositeElementType);
}

if (!is_contained(possibleTypes, valueType)) {
return emitError("expected value attribute type ")
<< interleaved(possibleTypes, " or ") << ", but got: " << valueType;
}

return success();
}

//===----------------------------------------------------------------------===//
// spirv.ControlBarrierOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1866,6 +1904,69 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// spirv.EXTSpecConstantCompositeReplicateOp
//===----------------------------------------------------------------------===//

ParseResult
spirv::EXTSpecConstantCompositeReplicateOp::parse(OpAsmParser &parser,
OperationState &result) {
StringAttr compositeName;
FlatSymbolRefAttr specConstRef;
const char *attrName = "spec_const";
NamedAttrList attrs;
Type type;

if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
result.attributes) ||
parser.parseLParen() ||
parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
parser.parseRParen() || parser.parseColonType(type))
return failure();

StringAttr compositeSpecConstituentName =
spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
result.name);
result.addAttribute(compositeSpecConstituentName, specConstRef);

StringAttr typeAttrName =
spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
result.addAttribute(typeAttrName, TypeAttr::get(type));

return success();
}

void spirv::EXTSpecConstantCompositeReplicateOp::print(OpAsmPrinter &printer) {
printer << " ";
printer.printSymbolName(getSymName());
printer << " (" << this->getConstituent() << ") : " << getType();
}

LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
auto compositeType = dyn_cast<spirv::CompositeType>(getType());
if (!compositeType)
return emitError("result type must be a composite type, but provided ")
<< getType();

Operation *constituentOp = SymbolTable::lookupNearestSymbolFrom(
(*this)->getParentOp(), this->getConstituent());
if (!constituentOp)
return emitError(
"splat spec constant reference defining constituent not found");

auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
if (!constituentSpecConstOp)
return emitError("constituent is not a spec constant");

Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
Type compositeElementType = compositeType.getElementType(0);
if (constituentType != compositeElementType)
return emitError("constituent has incorrect type: expected ")
<< compositeElementType << ", but provided " << constituentType;

return success();
}

//===----------------------------------------------------------------------===//
// spirv.SpecConstantOperation
//===----------------------------------------------------------------------===//
Expand Down
24 changes: 21 additions & 3 deletions mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ Value spirv::Deserializer::getValue(uint32_t id) {
return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
constInfo->first);
}
if (std::optional<std::pair<Attribute, Type>> constCompositeReplicateInfo =
getConstantCompositeReplicate(id)) {
return opBuilder.create<spirv::EXTConstantCompositeReplicateOp>(
unknownLoc, constCompositeReplicateInfo->second,
constCompositeReplicateInfo->first);
}
if (auto varOp = getGlobalVariable(id)) {
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
Expand All @@ -56,10 +62,18 @@ Value spirv::Deserializer::getValue(uint32_t id) {
SymbolRefAttr::get(constOp.getOperation()));
return referenceOfOp.getReference();
}
if (auto constCompositeOp = getSpecConstantComposite(id)) {
if (SpecConstantCompositeOp specConstCompositeOp =
getSpecConstantComposite(id)) {
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
unknownLoc, specConstCompositeOp.getType(),
SymbolRefAttr::get(specConstCompositeOp.getOperation()));
return referenceOfOp.getReference();
}
if (auto specConstCompositeReplicateOp =
getSpecConstantCompositeReplicate(id)) {
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
unknownLoc, constCompositeOp.getType(),
SymbolRefAttr::get(constCompositeOp.getOperation()));
unknownLoc, specConstCompositeReplicateOp.getType(),
SymbolRefAttr::get(specConstCompositeReplicateOp.getOperation()));
return referenceOfOp.getReference();
}
if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
Expand Down Expand Up @@ -175,8 +189,12 @@ LogicalResult spirv::Deserializer::processInstruction(
return processConstant(operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantComposite:
return processConstantComposite(operands);
case spirv::Opcode::OpConstantCompositeReplicateEXT:
return processConstantCompositeReplicateEXT(operands);
case spirv::Opcode::OpSpecConstantComposite:
return processSpecConstantComposite(operands);
case spirv::Opcode::OpSpecConstantCompositeReplicateEXT:
return processSpecConstantCompositeReplicateEXT(operands);
case spirv::Opcode::OpSpecConstantOp:
return processSpecConstantOperation(operands);
case spirv::Opcode::OpConstantTrue:
Expand Down
Loading