Skip to content

Commit dbc1c02

Browse files
committed
Changes:
* Use OptionalTypesMatchWith * Add constraint to verify both padding and mask are specified, as well as test.
1 parent e217711 commit dbc1c02

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,21 +233,23 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
233233

234234
def TileLoadOp : ArmSME_Op<"tile_load", [
235235
AttrSizedOperandSegments,
236-
TypesMatchWith<
237-
"padding type matches element type of result (if present)",
236+
OptionalTypesMatchWith<
237+
"padding type matches element type of result",
238238
"result", "padding",
239-
"::llvm::cast<VectorType>($_self).getElementType()",
240-
"!getPadding() || std::equal_to<>()"
239+
"::llvm::cast<VectorType>($_self).getElementType()"
241240
>,
242-
TypesMatchWith<
243-
"mask has i1 element type and same shape as result (if present)",
241+
OptionalTypesMatchWith<
242+
"mask has i1 element type and same shape as result",
244243
"result", "mask",
245244
"VectorType("
246245
"VectorType::Builder("
247246
"::llvm::cast<mlir::VectorType>($_self)"
248-
").setElementType(IntegerType::get($_self.getContext(), 1)))",
249-
"!getMask() || std::equal_to<>()"
250-
>
247+
").setElementType(IntegerType::get($_self.getContext(), 1)))"
248+
>,
249+
PredOpTrait<
250+
"both `padding` and `mask` should be provided or neither",
251+
CPred<"bool(getPadding()) == bool(getMask())">
252+
>,
251253
]> {
252254
let summary = "Tile load operation";
253255
let description = [{

mlir/test/Dialect/ArmSME/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,12 @@ func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64,
141141
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
142142
return
143143
}
144+
145+
// -----
146+
147+
func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64) {
148+
%c0 = arith.constant 0 : index
149+
// expected-error@+1 {{op failed to verify that both `padding` and `mask` should be provided or neither}}
150+
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, : memref<?x?xf64>, vector<[2]x[2]xf64>
151+
return
152+
}

0 commit comments

Comments
 (0)