-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR][XeGPU] make offsets optional for create_nd_tdesc #148335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
b9a6d98
2465050
1077871
42baa22
204d347
2793c81
f23ea03
0bb958b
6793689
4a96c71
689a8a5
02d3795
01718f4
5ef6ca9
26a222d
882313f
456534a
b6f016e
cd518d2
546a3f7
7846955
ded9552
97b6e39
ed1d48e
b3edff6
d3e935b
205fea7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -110,23 +110,27 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface | |
Variadic<Index>: $offsets, | ||
Variadic<Index>: $shape, | ||
Variadic<Index>: $strides, | ||
DenseI64ArrayAttr: $const_offsets, | ||
OptionalAttr<DenseI64ArrayAttr>: $const_offsets, | ||
OptionalAttr<DenseI64ArrayAttr>: $const_shape, | ||
OptionalAttr<DenseI64ArrayAttr>: $const_strides | ||
); | ||
let results = (outs XeGPU_TensorDesc: $TensorDesc); | ||
|
||
let assemblyFormat = [{ | ||
$source `` | ||
custom<DynamicIndexList>($offsets, $const_offsets) | ||
(`,` custom<DynamicIndexList>($shape, $const_shape)^ | ||
`,` custom<DynamicIndexList>($strides, $const_strides))? | ||
attr-dict `:` type($source) `->` qualified(type($TensorDesc)) | ||
}]; | ||
|
||
let hasVerifier = 1; | ||
|
||
let hasCustomAssemblyFormat = 1; | ||
|
||
let builders = [ | ||
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>, | ||
|
||
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source, | ||
"llvm::ArrayRef<OpFoldResult>": $shape, | ||
"llvm::ArrayRef<OpFoldResult>": $strides)>, | ||
|
||
OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source, | ||
"llvm::ArrayRef<OpFoldResult>": $shape, | ||
"llvm::ArrayRef<OpFoldResult>": $strides)>, | ||
|
||
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source, | ||
"llvm::ArrayRef<OpFoldResult>": $offsets)>, | ||
|
||
|
@@ -163,9 +167,30 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface | |
} | ||
|
||
ArrayRef<int64_t> getStaticOffsets(){ | ||
return getConstOffsets(); | ||
auto attr = getConstOffsetsAttr(); | ||
|
||
if (attr) | ||
return attr; | ||
|
||
auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType()); | ||
int rank = 0; | ||
if (memrefType) { | ||
//use source memref's rank, as source memref rank may be higher | ||
rank = memrefType.getRank(); | ||
} else { | ||
//nd_tdesc created from ui64, use nd_tdesc's rank | ||
rank = getTensorDescShape().size(); | ||
}; | ||
|
||
|
||
// The offsets are allowed to be empty. The Traits verification of OffsetSizeAndStrideOpInterface interface assumes offsets being present. | ||
// It is set to be MAX to indicate user not passed any value, instead of kDynamic which means offsets passed as value. | ||
setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, std::numeric_limits<int64_t>::max())); | ||
|
||
|
||
attr = getConstOffsetsAttr(); | ||
return attr; | ||
} | ||
|
||
|
||
/// wrapper for matching with OffsetSizeAndStrideOpInterface | ||
/// If source is IntegerType or `const_shape` is filled, | ||
/// it will return `const_shape`, such that mixes of `shape` | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -112,6 +112,64 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, | |||||
//===----------------------------------------------------------------------===// | ||||||
// XeGPU_CreateNdDescOp | ||||||
//===----------------------------------------------------------------------===// | ||||||
|
||||||
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||||||
Type tdesc, TypedValue<MemRefType> source) { | ||||||
[[maybe_unused]] auto ty = source.getType(); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ty is only used in the assert statement which is unused in release binary. |
||||||
assert(ty.hasStaticShape()); | ||||||
|
||||||
|
||||||
build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */, | ||||||
ValueRange({}) /* empty dynamic shape */, | ||||||
ValueRange({}) /* empty dynamic strides */, | ||||||
builder.getDenseI64ArrayAttr({}) /* const offsets */, | ||||||
adam-smnk marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
builder.getDenseI64ArrayAttr({}) /* empty const shape*/, | ||||||
builder.getDenseI64ArrayAttr({}) /* empty const strides*/); | ||||||
} | ||||||
|
||||||
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||||||
Type tdesc, TypedValue<MemRefType> source, | ||||||
llvm::ArrayRef<OpFoldResult> shape, | ||||||
llvm::ArrayRef<OpFoldResult> strides) { | ||||||
assert(shape.size() && strides.size() && shape.size() == strides.size()); | ||||||
|
||||||
|
||||||
llvm::SmallVector<int64_t> staticShape; | ||||||
llvm::SmallVector<int64_t> staticStrides; | ||||||
llvm::SmallVector<Value> dynamicShape; | ||||||
llvm::SmallVector<Value> dynamicStrides; | ||||||
|
||||||
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); | ||||||
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); | ||||||
|
||||||
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); | ||||||
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); | ||||||
|
||||||
build(builder, state, tdesc, source, ValueRange({}), dynamicShape, | ||||||
dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, | ||||||
staticStridesAttr); | ||||||
} | ||||||
|
||||||
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||||||
Type tdesc, TypedValue<IntegerType> source, | ||||||
llvm::ArrayRef<OpFoldResult> shape, | ||||||
llvm::ArrayRef<OpFoldResult> strides) { | ||||||
assert(shape.size() && strides.size() && shape.size() == strides.size()); | ||||||
|
||||||
|
||||||
llvm::SmallVector<int64_t> staticShape; | ||||||
llvm::SmallVector<int64_t> staticStrides; | ||||||
llvm::SmallVector<Value> dynamicShape; | ||||||
llvm::SmallVector<Value> dynamicStrides; | ||||||
|
||||||
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); | ||||||
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); | ||||||
|
||||||
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); | ||||||
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); | ||||||
|
||||||
build(builder, state, tdesc, source, ValueRange({}), dynamicShape, | ||||||
dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, | ||||||
staticStridesAttr); | ||||||
} | ||||||
|
||||||
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||||||
Type tdesc, TypedValue<MemRefType> source, | ||||||
llvm::ArrayRef<OpFoldResult> offsets) { | ||||||
|
@@ -125,8 +183,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | |||||
build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */, | ||||||
ValueRange({}) /* empty dynamic shape */, | ||||||
ValueRange({}) /* empty dynamic strides */, | ||||||
staticOffsets /* const offsets */, {} /* empty const shape*/, | ||||||
{} /* empty const strides*/); | ||||||
builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */, | ||||||
{} /* empty const shape*/, {} /* empty const strides*/); | ||||||
} | ||||||
|
||||||
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. where is the new build methods implemented?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||||||
|
@@ -221,6 +279,246 @@ LogicalResult CreateNdDescOp::verify() { | |||||
return success(); | ||||||
} | ||||||
|
||||||
ParseResult parseOptionalDynamicIndexList( | ||||||
OpAsmParser &parser, | ||||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, | ||||||
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, | ||||||
SmallVectorImpl<Type> *valueTypes = nullptr, | ||||||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why we can not reuse There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The offset is provided as optional bracket [], so we need to customize parseDynamicIndexList. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. refactored and simplified the custom parser. |
||||||
|
||||||
SmallVector<int64_t, 4> integerVals; | ||||||
SmallVector<bool, 4> scalableVals; | ||||||
auto parseIntegerOrValue = [&]() { | ||||||
OpAsmParser::UnresolvedOperand operand; | ||||||
auto res = parser.parseOptionalOperand(operand); | ||||||
|
||||||
// When encountering `[`, assume that this is a scalable index. | ||||||
scalableVals.push_back(parser.parseOptionalLSquare().succeeded()); | ||||||
|
||||||
if (res.has_value() && succeeded(res.value())) { | ||||||
values.push_back(operand); | ||||||
integerVals.push_back(ShapedType::kDynamic); | ||||||
if (valueTypes && parser.parseColonType(valueTypes->emplace_back())) | ||||||
return failure(); | ||||||
} else { | ||||||
int64_t integer; | ||||||
if (failed(parser.parseInteger(integer))) | ||||||
return failure(); | ||||||
integerVals.push_back(integer); | ||||||
} | ||||||
|
||||||
// If this is assumed to be a scalable index, verify that there's a closing | ||||||
// `]`. | ||||||
if (scalableVals.back() && parser.parseOptionalRSquare().failed()) | ||||||
return failure(); | ||||||
return success(); | ||||||
}; | ||||||
if (parser.parseOptionalLSquare().succeeded()) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I assume that for no-offset case this check will fail?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add a comment here like "If the optional values are given there must be left bracket" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. Added |
||||||
if (parser.parseCommaSeparatedList(parseIntegerOrValue) || | ||||||
parser.parseRSquare()) | ||||||
return parser.emitError(parser.getNameLoc()) | ||||||
<< "expected SSA value or integer"; | ||||||
|
<< "expected SSA value or integer"; | |
<< "expected a list of SSA values or integers"; |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wouldn't it be nicer to use a keyword for offset as well? For the optional case it will be empty square brackets.
offsets : [], strides : [...], shapes: []
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are moving offsets to load_nd.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should avoid using tablegen generate variable names _odsPrinter
and use something more readable like printer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't expect this code will stay permanent. Keeping them same as Talegen generated printer code helps debugging.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is not safe to infer the rank from the TensorDesc, since TensorDesc could have fewer rank than offset. You can simply use
int rank = getStaticSizes().size()
instead;There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, user doesn't specify neither const offsets or dynamic offset values. So I assume that we can only infer the rank from TensorDesc. Not sure getStaticSizes() can give us correct result.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed