@@ -231,7 +231,24 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
231
231
let assemblyFormat = "attr-dict `:` type($res)";
232
232
}
233
233
234
- def TileLoadOp : ArmSME_Op<"tile_load"> {
234
+ def TileLoadOp : ArmSME_Op<"tile_load", [
235
+ AttrSizedOperandSegments,
236
+ TypesMatchWith<
237
+ "padding type matches element type of result (if present)",
238
+ "result", "padding",
239
+ "::llvm::cast<VectorType>($_self).getElementType()",
240
+ "!getPadding() || std::equal_to<>()"
241
+ >,
242
+ TypesMatchWith<
243
+ "mask has i1 element type and same shape as result (if present)",
244
+ "result", "mask",
245
+ "VectorType("
246
+ "VectorType::Builder("
247
+ "::llvm::cast<mlir::VectorType>($_self)"
248
+ ").setElementType(IntegerType::get($_self.getContext(), 1)))",
249
+ "!getMask() || std::equal_to<>()"
250
+ >
251
+ ]> {
235
252
let summary = "Tile load operation";
236
253
let description = [{
237
254
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
@@ -242,6 +259,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
242
259
dimensions, since the operation is scalable, and the element type must be a
243
260
scalar that matches the element type of the result.
244
261
262
+ An optional SSA value `padding` of the same elemental type as the MemRef is
263
+ provided to specify a fallback value in the case of masking.
264
+
265
+ An optional SSA value `mask` may be specified to mask out elements read
266
+ from the MemRef. The `mask` type is an `i1` vector with a shape that
267
+ matches how elements are read from the MemRef. Elements whose corresponding
268
+ mask element is `0` are masked out and replaced with `padding`.
269
+
270
+ If either `padding` or `mask` are specified, both must be specified.
271
+
245
272
Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
246
273
```mlir
247
274
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -256,10 +283,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
256
283
```mlir
257
284
%tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
258
285
```
286
+
287
+ Example 4: Masked load of int 32-bit element ZA tile with horizontal layout (default) from memory.
288
+ ```mlir
289
+ %tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
290
+ ```
259
291
}];
260
292
let arguments = (ins
261
293
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
262
294
Variadic<Index>:$indices,
295
+ Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
263
296
ArmSME_TileSliceLayoutAttr:$layout
264
297
);
265
298
let results = (outs SMETile:$result);
@@ -273,9 +306,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
273
306
}
274
307
}];
275
308
309
+ let builders = [
310
+ OpBuilder<(ins "VectorType":$resultType, "Value":$base,
311
+ "ValueRange":$indices, "TileSliceLayout":$layout), [{
312
+ build($_builder, $_state, resultType, base, indices, {}, {}, layout);
313
+ }]>,
314
+ OpBuilder<(ins "VectorType":$resultType, "Value":$base,
315
+ "ValueRange":$indices), [{
316
+ build($_builder, $_state, resultType, base, indices, {}, {}, {});
317
+ }]>,
318
+ ];
319
+
276
320
let assemblyFormat =
277
- "$base `[` $indices `]` (`layout` ` ` $layout ^)? attr-dict "
278
- "`:` type($base) `,` type($result)";
321
+ "$base `[` $indices `]` (`,` $padding `, ` $mask ^)? (`layout` `` $layout^)? "
322
+ "attr-dict `:` type($base) `,` type($result)";
279
323
}
280
324
281
325
def TileStoreOp : ArmSME_Op<"tile_store"> {
0 commit comments