Skip to content

Commit c96e2cd

Browse files
authored
[mlir][XeGPU] Update utils for LayoutAttr and SliceAttr support (#154819)
1 parent cf0f7f6 commit c96e2cd

File tree

10 files changed

+235
-135
lines changed

10 files changed

+235
-135
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
185185
InterfaceMethod<"Check the availability of workgroup level layouts",
186186
"bool",
187187
"isForWorkgroup">,
188+
InterfaceMethod<"Check the availability of subgroup level layouts",
189+
"bool",
190+
"isForSubgroup">,
188191
InterfaceMethod<"Get the rank of attribute",
189192
"int64_t",
190193
"getRank">,
@@ -197,14 +200,26 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
197200
return 0;
198201
}], [{}]>,
199202
InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
200-
"std::optional<SmallVector<int64_t>>",
203+
"SmallVector<int64_t>",
201204
"getSgLayoutAsInt">,
202205
InterfaceMethod<"Get the SgData field of the attribute as integer array",
203-
"std::optional<SmallVector<int64_t>>",
206+
"SmallVector<int64_t>",
204207
"getSgDataAsInt">,
208+
InterfaceMethod<"Get the InstData field of the attribute as integer array",
209+
"SmallVector<int64_t>",
210+
"getInstDataAsInt">,
211+
InterfaceMethod<"Get the LaneLayout field of the attribute as integer array",
212+
"SmallVector<int64_t>",
213+
"getLaneLayoutAsInt">,
214+
InterfaceMethod<"Get the LaneData field of the attribute as integer array",
215+
"SmallVector<int64_t>",
216+
"getLaneDataAsInt">,
205217
InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
206218
"xegpu::DistributeLayoutAttr",
207219
"dropSgLayoutAndData">,
220+
InterfaceMethod<"Derive a new layout by dropping InstData",
221+
"xegpu::DistributeLayoutAttr",
222+
"dropInstData">,
208223
InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
209224
indices based on the effective subgroup layout.}],
210225
"FailureOr<SmallVector<Value>>",
@@ -376,16 +391,34 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
376391
getLaneLayout(), getLaneData(), getOrder());
377392
}
378393

379-
std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
394+
SmallVector<int64_t> getSgLayoutAsInt() const {
380395
if (DenseI32ArrayAttr layout = getSgLayout())
381396
return llvm::to_vector_of<int64_t>(layout.asArrayRef());
382-
return std::nullopt;
397+
return {};
383398
}
384399

385-
std::optional<SmallVector<int64_t>> getSgDataAsInt() const {
400+
SmallVector<int64_t> getSgDataAsInt() const {
386401
if (DenseI32ArrayAttr data = getSgData())
387402
return llvm::to_vector_of<int64_t>(data.asArrayRef());
388-
return std::nullopt;
403+
return {};
404+
}
405+
406+
SmallVector<int64_t> getInstDataAsInt() const {
407+
if (DenseI32ArrayAttr inst = getInstData())
408+
return llvm::to_vector_of<int64_t>(inst.asArrayRef());
409+
return {};
410+
}
411+
412+
SmallVector<int64_t> getLaneLayoutAsInt() const {
413+
if (DenseI32ArrayAttr layout = getLaneLayout())
414+
return llvm::to_vector_of<int64_t>(layout.asArrayRef());
415+
return {};
416+
}
417+
418+
SmallVector<int64_t> getLaneDataAsInt() const {
419+
if (DenseI32ArrayAttr data = getLaneData())
420+
return llvm::to_vector_of<int64_t>(data.asArrayRef());
421+
return {};
389422
}
390423

391424
/// Delinearizes a linear subgroup ID into its multidimensional indices
@@ -466,26 +499,67 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
466499

467500
/// Returns the SgLayout of the attribute, computed by applying
468501
/// the slice dimensions to the underlying LayoutAttr.
469-
std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
502+
SmallVector<int64_t> getSgLayoutAsInt() const {
470503
SliceAttr attr = flatten();
471504
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
472-
if (auto layout = parent.getSgLayoutAsInt()) {
505+
auto layout = parent.getSgLayoutAsInt();
506+
if (layout.size()) {
473507
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
474-
return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*layout), dims);
508+
return XeGPUDialect::slice(ArrayRef<int64_t>(layout), dims);
475509
}
476-
return std::nullopt;
510+
return {};
477511
}
478512

479513
/// Returns the SgData of the attribute, computed by applying
480514
/// the slice dimensions to the underlying LayoutAttr.
481-
std::optional<SmallVector<int64_t>> getSgDataAsInt() const {
515+
SmallVector<int64_t> getSgDataAsInt() const {
516+
SliceAttr attr = flatten();
517+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
518+
auto data = parent.getSgDataAsInt();
519+
if (data.size()) {
520+
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
521+
return XeGPUDialect::slice(ArrayRef<int64_t>(data), dims);
522+
}
523+
return {};
524+
}
525+
526+
/// Returns the InstData of the attribute, computed by applying
527+
/// the slice dimensions to the underlying LayoutAttr.
528+
SmallVector<int64_t> getInstDataAsInt() const {
529+
SliceAttr attr = flatten();
530+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
531+
auto inst = parent.getInstDataAsInt();
532+
if (inst.size()) {
533+
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
534+
return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(inst), dims);
535+
}
536+
return {};
537+
}
538+
539+
/// Returns the LaneLayout of the attribute, computed by applying
540+
/// the slice dimensions to the underlying LayoutAttr.
541+
SmallVector<int64_t> getLaneLayoutAsInt() const {
542+
SliceAttr attr = flatten();
543+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
544+
auto layout = parent.getLaneLayoutAsInt();
545+
if (layout.size()) {
546+
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
547+
return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(layout), dims);
548+
}
549+
return {};
550+
}
551+
552+
/// Returns the LaneData of the attribute, computed by applying
553+
/// the slice dimensions to the underlying LayoutAttr.
554+
SmallVector<int64_t> getLaneDataAsInt() const {
482555
SliceAttr attr = flatten();
483556
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
484-
if (auto data = parent.getSgDataAsInt()) {
557+
auto data = parent.getLaneDataAsInt();
558+
if (data.size()) {
485559
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
486-
return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*data), dims);
560+
return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(data), dims);
487561
}
488-
return std::nullopt;
562+
return {};
489563
}
490564

491565
SliceAttr dropSgLayoutAndData() {

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def XeGPU_Dialect : Dialect {
4040
let extraClassDeclaration = [{
4141
/// Checks if the given shape can be evenly distributed based on the layout
4242
/// and data factors provided by the LayoutAttr.
43-
static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);
43+
static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::DistributeLayoutAttr attr);
4444

4545
/// drops/slices the shape in the specified dims, and return the rest. e.g.,
4646
/// for shape = [32, 64, 8], dims = [0, 2], it will return [64]

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,8 +1242,8 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
12421242
the IR is lowered to WI level because that is the end result of all distributions.
12431243
}];
12441244
let arguments = (ins XeGPU_VectorType: $source,
1245-
XeGPU_LayoutAttr: $input_layout,
1246-
XeGPU_LayoutAttr: $target_layout);
1245+
DistributeLayoutAttr: $input_layout,
1246+
DistributeLayoutAttr: $target_layout);
12471247
let results = (outs XeGPU_VectorType: $result);
12481248
let assemblyFormat = [{
12491249
$source prop-dict attr-dict `:` type($source)

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
1010
#define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
1111

12+
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1213
#include "mlir/IR/BuiltinTypes.h"
1314
#include "mlir/IR/OpDefinition.h"
1415
namespace mlir {
@@ -21,6 +22,7 @@ class ValueRange;
2122
class TypeConverter;
2223

2324
namespace xegpu {
25+
class DistributeLayoutAttr;
2426
class LayoutAttr;
2527
class TensorDescType;
2628
} // namespace xegpu
@@ -60,46 +62,58 @@ FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy);
6062
FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
6163
LayoutAttr layout);
6264

63-
/// Return the attribute name for the OpOperand to attach LayoutAttr
65+
/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
6466
std::string getLayoutName(const OpOperand &operand);
6567

66-
/// Return the attribute name for the OpResult to attach LayoutAttr
68+
/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
6769
std::string getLayoutName(const OpResult result);
6870

69-
/// Retrieves the LayoutAttr associated with a given Value. For TensorDescType
70-
/// values, the LayoutAttr is extracted from the TensorDescType itself. For
71-
/// other values, it is obtained from the attributes of the defining operation.
72-
/// Returns nullptr if no LayoutAttr is found.
73-
LayoutAttr getLayoutAttr(const Value value);
71+
/// Retrieves the DistributeLayoutAttr associated with a given Value. For
72+
/// TensorDescType values, the DistributeLayoutAttr is extracted from the
73+
/// TensorDescType itself. For other values, it is obtained from the attributes
74+
/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is
75+
/// found.
76+
DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
7477

75-
/// Retrieves the LayoutAttr associated with a given OpOperand. It will
76-
/// first check the operand_layout_{id} of the owner operation. If not found,
77-
/// it will check the operand itself and its defining op.
78-
LayoutAttr getLayoutAttr(const OpOperand &opr);
78+
template <typename AttrTy>
79+
AttrTy getDistributeLayoutAttrOfType(const Value value) {
80+
return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(value));
81+
}
82+
83+
/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
84+
/// will first check the operand_layout_{id} of the owner operation. If not
85+
/// found, it will check the operand itself and its defining op.
86+
DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
87+
88+
template <typename AttrTy>
89+
AttrTy getDistributeLayoutAttrOfType(const OpOperand &opr) {
90+
return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(opr));
91+
}
7992

8093
/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
8194
template <typename T,
8295
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
8396
std::is_same_v<T, OpResult>>>
8497
void removeLayoutAttr(const T &operandOrResult);
8598

86-
/// Removes the LayoutAttr for each OpOperand and OpResult of the given
87-
/// operation if they exist. If the operation contains regions, it is also
99+
/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
100+
/// given operation if they exist. If the operation contains regions, it is also
88101
/// applied recursively to the contained operations
89102
void removeLayoutAttrs(Operation *op);
90103

91-
/// Sets the LayoutAttr for a given OpOperand or OpResult by attaching
104+
/// Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching
92105
/// it to the owner's dictionary attributes
93106
template <typename T,
94107
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
95108
std::is_same_v<T, OpResult>>>
96-
void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout);
97-
98-
/// Set the LayoutAttr for each OpOperand and OpResult of the given operation.
99-
/// If the operation contains regions, it is also applied recursively to the
100-
/// contained operations
101-
void setLayoutAttrs(Operation *op,
102-
function_ref<LayoutAttr(Value)> getLayoutImpl);
109+
void setDistributeLayoutAttr(const T &operandOrResult,
110+
const DistributeLayoutAttr layout);
111+
112+
/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given
113+
/// operation. If the operation contains regions, it is also applied recursively
114+
/// to the contained operations
115+
void setDistributeLayoutAttrs(
116+
Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
103117

104118
/// Extract a set of small vectors from a value with a given shape using
105119
/// vector.extract_stride_slice

0 commit comments

Comments
 (0)