Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b9a6d98
init code
Jianhui-Li Jul 1, 2025
2465050
add tests
Jianhui-Li Jul 2, 2025
1077871
git-clang-format
Jianhui-Li Jul 2, 2025
42baa22
add more tests
Jianhui-Li Jul 2, 2025
204d347
git-clang-format
Jianhui-Li Jul 2, 2025
2793c81
add ui64 case support
Jianhui-Li Jul 12, 2025
f23ea03
modify ui64 test
Jianhui-Li Jul 12, 2025
0bb958b
Merge branch 'main' into dialect-assembly-format
Jianhui-Li Jul 12, 2025
6793689
remove unnecessary comments
Jianhui-Li Jul 12, 2025
4a96c71
fix VectorToXeGPU tests
Jianhui-Li Jul 14, 2025
689a8a5
tweak default offset value
Jianhui-Li Jul 14, 2025
02d3795
git-clang-format
Jianhui-Li Jul 14, 2025
01718f4
add builders
Jianhui-Li Jul 15, 2025
5ef6ca9
git-clang-format
Jianhui-Li Jul 15, 2025
26a222d
Merge branch 'main' into dialect-assembly-format
Jianhui-Li Jul 15, 2025
882313f
simplify custom parser
Jianhui-Li Jul 15, 2025
456534a
add comma before shape and strides
Jianhui-Li Jul 15, 2025
b6f016e
tie the offsets rank to input tensor shape instead of tdesc
Jianhui-Li Jul 15, 2025
cd518d2
git-clang-format
Jianhui-Li Jul 15, 2025
546a3f7
add verifier for invalid cases
Jianhui-Li Jul 15, 2025
7846955
git-clang-format
Jianhui-Li Jul 16, 2025
ded9552
add comments
Jianhui-Li Jul 16, 2025
97b6e39
simplify custom print
Jianhui-Li Jul 17, 2025
ed1d48e
git-clang-format
Jianhui-Li Jul 17, 2025
b3edff6
Merge branch 'main' into dialect-assembly-format
Jianhui-Li Jul 17, 2025
d3e935b
use simpler interface for DenseI64ArrayAttr
Jianhui-Li Jul 17, 2025
205fea7
address feedback
Jianhui-Li Jul 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,27 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
Variadic<Index>: $offsets,
Variadic<Index>: $shape,
Variadic<Index>: $strides,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_shape,
OptionalAttr<DenseI64ArrayAttr>: $const_strides
);
let results = (outs XeGPU_TensorDesc: $TensorDesc);

let assemblyFormat = [{
$source ``
custom<DynamicIndexList>($offsets, $const_offsets)
(`,` custom<DynamicIndexList>($shape, $const_shape)^
`,` custom<DynamicIndexList>($strides, $const_strides))?
attr-dict `:` type($source) `->` qualified(type($TensorDesc))
}];

let hasVerifier = 1;

let hasCustomAssemblyFormat = 1;

let builders = [
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,

OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
"llvm::ArrayRef<OpFoldResult>": $shape,
"llvm::ArrayRef<OpFoldResult>": $strides)>,

OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
"llvm::ArrayRef<OpFoldResult>": $shape,
"llvm::ArrayRef<OpFoldResult>": $strides)>,

OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
"llvm::ArrayRef<OpFoldResult>": $offsets)>,

Expand Down Expand Up @@ -163,9 +167,30 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
}

ArrayRef<int64_t> getStaticOffsets(){
return getConstOffsets();
auto attr = getConstOffsetsAttr();

if (attr)
return attr;

auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
int rank = 0;
if (memrefType) {
//use source memref's rank, as source memref rank may be higher
rank = memrefType.getRank();
} else {
//nd_tdesc created from ui64, use nd_tdesc's rank
rank = getTensorDescShape().size();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not safe to infer the rank from the TensorDesc, since TensorDesc could have fewer rank than offset. You can simply use int rank = getStaticSizes().size() instead;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, user doesn't specify neither const offsets or dynamic offset values. So I assume that we can only infer the rank from TensorDesc. Not sure getStaticSizes() can give us correct result.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need of brackets for single line if and else.


// The offsets are allowed to be empty. The Traits verification of OffsetSizeAndStrideOpInterface interface assumes offsets being present.
// It is set to be MAX to indicate user not passed any value, instead of kDynamic which means offsets passed as value.
setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, std::numeric_limits<int64_t>::max()));

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably we need to reuse this constant in future. Better to define it somewhere.

static constexpr int64_t optionalValue = std::numeric_limits<int64_t>::max();

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parser and print code is supposed to removed once we finish the transition that move the offsets from create_nd_tdesc definition to load_nd. So no plan to reuse.

attr = getConstOffsetsAttr();
return attr;
}


/// wrapper for matching with OffsetSizeAndStrideOpInterface
/// If source is IntegerType or `const_shape` is filled,
/// it will return `const_shape`, such that mixes of `shape`
Expand Down
302 changes: 300 additions & 2 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,64 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Type tdesc, TypedValue<MemRefType> source) {
[[maybe_unused]] auto ty = source.getType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think [[maybe_unused]] is not needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ty is only used in the assert statement which is unused in release binary.

assert(ty.hasStaticShape());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some error text in assert. like "expecting a memref with static shape"


build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
ValueRange({}) /* empty dynamic shape */,
ValueRange({}) /* empty dynamic strides */,
builder.getDenseI64ArrayAttr({}) /* const offsets */,
builder.getDenseI64ArrayAttr({}) /* empty const shape*/,
builder.getDenseI64ArrayAttr({}) /* empty const strides*/);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Type tdesc, TypedValue<MemRefType> source,
llvm::ArrayRef<OpFoldResult> shape,
llvm::ArrayRef<OpFoldResult> strides) {
assert(shape.size() && strides.size() && shape.size() == strides.size());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some error text. why this invariant must be satisfied.


llvm::SmallVector<int64_t> staticShape;
llvm::SmallVector<int64_t> staticStrides;
llvm::SmallVector<Value> dynamicShape;
llvm::SmallVector<Value> dynamicStrides;

dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
staticStridesAttr);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Type tdesc, TypedValue<IntegerType> source,
llvm::ArrayRef<OpFoldResult> shape,
llvm::ArrayRef<OpFoldResult> strides) {
assert(shape.size() && strides.size() && shape.size() == strides.size());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here.


llvm::SmallVector<int64_t> staticShape;
llvm::SmallVector<int64_t> staticStrides;
llvm::SmallVector<Value> dynamicShape;
llvm::SmallVector<Value> dynamicStrides;

dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
staticStridesAttr);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Type tdesc, TypedValue<MemRefType> source,
llvm::ArrayRef<OpFoldResult> offsets) {
Expand All @@ -125,8 +183,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
ValueRange({}) /* empty dynamic shape */,
ValueRange({}) /* empty dynamic strides */,
staticOffsets /* const offsets */, {} /* empty const shape*/,
{} /* empty const strides*/);
builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
{} /* empty const shape*/, {} /* empty const strides*/);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where are the new build methods implemented?

    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,

    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
                   "llvm::ArrayRef<OpFoldResult>": $shape,
                   "llvm::ArrayRef<OpFoldResult>": $strides)>,

    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
                   "llvm::ArrayRef<OpFoldResult>": $shape,
                   "llvm::ArrayRef<OpFoldResult>": $strides)>,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

Expand Down Expand Up @@ -221,6 +279,246 @@ LogicalResult CreateNdDescOp::verify() {
return success();
}

ParseResult parseOptionalDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
SmallVectorImpl<Type> *valueTypes = nullptr,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we can not reuse parseDynamicIndexList method and avoid this? I see lost of logic replicated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The offset is provided as optional bracket [], so we need to customize parseDynamicIndexList.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactored and simplified the custom parser.


SmallVector<int64_t, 4> integerVals;
SmallVector<bool, 4> scalableVals;
auto parseIntegerOrValue = [&]() {
OpAsmParser::UnresolvedOperand operand;
auto res = parser.parseOptionalOperand(operand);

// When encountering `[`, assume that this is a scalable index.
scalableVals.push_back(parser.parseOptionalLSquare().succeeded());

if (res.has_value() && succeeded(res.value())) {
values.push_back(operand);
integerVals.push_back(ShapedType::kDynamic);
if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
return failure();
} else {
int64_t integer;
if (failed(parser.parseInteger(integer)))
return failure();
integerVals.push_back(integer);
}

// If this is assumed to be a scalable index, verify that there's a closing
// `]`.
if (scalableVals.back() && parser.parseOptionalRSquare().failed())
return failure();
return success();
};
if (parser.parseOptionalLSquare().succeeded()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume that for no-offset case this check will fail?
Example:

create_nd %src shape: [] strides: []

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a comment here like "If the optional values are given there must be left bracket"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Added

if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
parser.parseRSquare())
return parser.emitError(parser.getNameLoc())
<< "expected SSA value or integer";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
<< "expected SSA value or integer";
<< "expected a list of SSA values or integers";

integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
return success();
}
return success();
}

::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
::mlir::OpAsmParser::UnresolvedOperand sourceRawOperand{};
::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> sourceOperands(
&sourceRawOperand, 1);
::llvm::SMLoc sourceOperandsLoc;

::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
offsetsOperands;
::llvm::SMLoc offsetsOperandsLoc;
::mlir::DenseI64ArrayAttr const_offsetsAttr;
::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> shapeOperands;
::llvm::SMLoc shapeOperandsLoc;
::mlir::DenseI64ArrayAttr const_shapeAttr;
::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
stridesOperands;
::llvm::SMLoc stridesOperandsLoc;
::mlir::DenseI64ArrayAttr const_stridesAttr;
::mlir::Type sourceRawType{};
::llvm::ArrayRef<::mlir::Type> sourceTypes(&sourceRawType, 1);
::mlir::Type TensorDescRawType{};
::llvm::ArrayRef<::mlir::Type> TensorDescTypes(&TensorDescRawType, 1);

sourceOperandsLoc = parser.getCurrentLocation();
if (parser.parseOperand(sourceRawOperand))
return ::mlir::failure();

offsetsOperandsLoc = parser.getCurrentLocation();

DenseBoolArrayAttr scalableFlags;
auto odsResult = parseOptionalDynamicIndexList(
parser, offsetsOperands, const_offsetsAttr, scalableFlags);

if (const_offsetsAttr) {
if (odsResult)
return ::mlir::failure();
result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets =
const_offsetsAttr;
}

if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
if (parser.parseColon())
return ::mlir::failure();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't it be nicer to use a keyword for offset as well? For the optional case it will be empty square brackets.

offsets : [], strides : [...], shapes: []

Copy link
Contributor Author

@Jianhui-Li Jianhui-Li Jul 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are moving offsets to load_nd.

{
shapeOperandsLoc = parser.getCurrentLocation();
auto odsResult =
parseDynamicIndexList(parser, shapeOperands, const_shapeAttr);
if (const_shapeAttr) {
if (odsResult)
return ::mlir::failure();
result.getOrAddProperties<CreateNdDescOp::Properties>().const_shape =
const_shapeAttr;
}
}

if (parser.parseKeyword("strides"))
return ::mlir::failure();
if (parser.parseColon())
return ::mlir::failure();
{
stridesOperandsLoc = parser.getCurrentLocation();
auto odsResult =
parseDynamicIndexList(parser, stridesOperands, const_stridesAttr);
if (const_stridesAttr) {
if (odsResult)
return ::mlir::failure();
result.getOrAddProperties<CreateNdDescOp::Properties>().const_strides =
const_stridesAttr;
}
}
}
{
auto loc = parser.getCurrentLocation();
if (parser.parseOptionalAttrDict(result.attributes))
return ::mlir::failure();
if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
return parser.emitError(loc)
<< "'" << result.name.getStringRef() << "' op ";
})))
return ::mlir::failure();
}
if (parser.parseColon())
return ::mlir::failure();

{
::mlir::Type type;
if (parser.parseCustomTypeWithFallback(type))
return ::mlir::failure();
sourceRawType = type;
}
if (parser.parseArrow())
return ::mlir::failure();

if (parser.parseType(TensorDescRawType))
return ::mlir::failure();

::llvm::copy(::llvm::ArrayRef<int32_t>(
{1, static_cast<int32_t>(offsetsOperands.size()),
static_cast<int32_t>(shapeOperands.size()),
static_cast<int32_t>(stridesOperands.size())}),
result.getOrAddProperties<CreateNdDescOp::Properties>()
.operandSegmentSizes.begin());

::mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType();
result.addTypes(TensorDescTypes);

if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc,
result.operands))
return ::mlir::failure();

if (parser.resolveOperands(offsetsOperands, odsBuildableType0,
offsetsOperandsLoc, result.operands))
return ::mlir::failure();

if (parser.resolveOperands(shapeOperands, odsBuildableType0, shapeOperandsLoc,
result.operands))
return ::mlir::failure();

if (parser.resolveOperands(stridesOperands, odsBuildableType0,
stridesOperandsLoc, result.operands))
return ::mlir::failure();
return ::mlir::success();
}

void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
_odsPrinter << ' ';
_odsPrinter << getSource();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should avoid using tablegen generate variable names _odsPrinter and use something more readable like printer

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't expect this code will stay permanent. Keeping them same as Talegen generated printer code helps debugging.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed.


auto constOffsetsAttr = getConstOffsetsAttr();
bool printOffsets = false;
if (constOffsetsAttr && constOffsetsAttr.size() > 0) {
auto firstVal = constOffsetsAttr.asArrayRef()[0];
if (firstVal != std::numeric_limits<int64_t>::max()) {
printOffsets = true;
}
}
if (printOffsets) {

printDynamicIndexList(_odsPrinter, *this, getOffsets(),
getConstOffsetsAttr());
}
if (((!getShape().empty()) || (getConstShapeAttr()))) {
_odsPrinter << ' ' << "shape";
_odsPrinter << ' ' << ":";
_odsPrinter << ' ';
printDynamicIndexList(_odsPrinter, *this, getShape(), getConstShapeAttr());
_odsPrinter << ' ' << "strides";
_odsPrinter << ' ' << ":";
_odsPrinter << ' ';
printDynamicIndexList(_odsPrinter, *this, getStrides(),
getConstStridesAttr());
}
::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
elidedAttrs.push_back("operandSegmentSizes");
elidedAttrs.push_back("const_offsets");
elidedAttrs.push_back("const_shape");
elidedAttrs.push_back("const_strides");
_odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
_odsPrinter << ' ' << ":";
_odsPrinter << ' ';
{
auto type = getSource().getType();
if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type))
_odsPrinter.printStrippedAttrOrType(validType);
else
_odsPrinter << type;
}
_odsPrinter << ' ' << "->";
_odsPrinter << ' ';
// _odsPrinter << getTensorDesc().getType();

_odsPrinter << "!xegpu.tensor_desc<";

auto tDesc = getTensorDesc().getType();
auto shape = tDesc.getShape();
for (int64_t dim : shape) {
if (mlir::ShapedType::isDynamic(dim))
_odsPrinter << '?';
else
_odsPrinter << dim;
_odsPrinter << 'x';
}

_odsPrinter << tDesc.getElementType();

if (auto encoding = tDesc.getEncoding())
_odsPrinter << ", " << encoding;

if (auto layout = tDesc.getLayout())
_odsPrinter << ", " << layout;

_odsPrinter << ">";
}

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
Expand Down
Loading