diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 6a439bfb09078..a5725d6f1507e 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -858,7 +858,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz let arguments = (ins Variadic:$inputs, Variadic:$outputs, - DefaultValuedOptionalAttr:$indexing_maps + DefaultValuedOptionalAttr< + AffineMapArrayAttr, + "BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())" + >:$indexing_maps, + DefaultValuedOptionalAttr:$cast ); let results = (outs Variadic:$result_tensors); let regions = (region AnyRegion:$region); @@ -884,9 +888,10 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, - CArg<"ArrayRef", "{}">:$attributes), + "Attribute":$cast, CArg<"ArrayRef", "{}">:$attributes), [{ $_state.addOperands(operands); + $_state.addAttribute("cast", cast); $_state.addAttributes(attributes); $_state.addTypes(resultTensorTypes); (void)$_state.addRegion(), diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index b756a67f3ba7a..42ea0e1197ef1 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -3951,11 +3951,18 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, RegionBuilderHelper helper(b, block); SmallVector yields; + TypeFn castVal = TypeFn::cast_signed; + auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) { + return attr.getName() == "cast"; + }); + if (castIter != attrs.end()) { + if (auto attr = llvm::dyn_cast(castIter->getValue())) + castVal = attr.getValue(); + } + auto toType = block.getArgument(2).getType(); - Value castValA = - helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0)); - Value castValB = - helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1)); + Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0)); + Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1)); Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB); Value addVal = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal); @@ -4004,11 +4011,6 @@ ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) { } void BatchMatmulOp::print(OpAsmPrinter &p) { - SmallVector elidedAttrs = { - "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"}; - ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), - elidedAttrs); - SmallVector indexingMaps = llvm::map_to_vector( BatchMatmulOp::getDefaultIndexingMaps(getContext()), [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); @@ -4018,6 +4020,11 @@ void BatchMatmulOp::print(OpAsmPrinter &p) { [&](Attribute attr) { p.printAttribute(attr); }); p << "]"; } + + SmallVector elidedAttrs = { + "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"}; + ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), + elidedAttrs); } /// Verify the user defined indexing maps. diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 5cda4769d593f..c5fbb833ee399 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -149,7 +149,8 @@ def __init__( generic = region_op(GenericOp_, terminator=YieldOp) -def matmul( +def create_op( + op_type, *ins: Union[Operation, OpView, Value], outs: Sequence[Union[Operation, OpView, Value]], indexing_maps: Optional[Sequence[AffineMapAttr]] = None, @@ -161,7 +162,7 @@ def matmul( init = _get_op_result_or_value(outs[0]) result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] - op = MatmulOp( + op = op_type( result_tensors=result_types, inputs=ins, outputs=[init], @@ -172,24 +173,32 @@ def matmul( return op +def matmul( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, + cast: Optional[Union[TypeFn, Attribute]] = None, +): + return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast) + + +def batch_matmul( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, + cast: Optional[Union[TypeFn, Attribute]] = None, +): + return create_op( + BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + + def contract( *ins: Union[Operation, OpView, Value], outs: Sequence[Union[Operation, OpView, Value]], indexing_maps: Sequence[AffineMapAttr], cast: Optional[Union[TypeFn, Attribute]] = None, ): - ins = [_get_op_result_or_value(input) for input in ins] - if len(outs) > 1: - raise ValueError(f"{outs=} must have length 1.") - init = _get_op_result_or_value(outs[0]) - result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] - - op = ContractOp( - result_tensors=result_types, - inputs=ins, - outputs=[init], - indexing_maps=indexing_maps, - cast=cast, + return create_op( + ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast ) - fill_builtin_region(op.operation) - return op diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir index 8474eeac0db5b..1bd9c8825b05e 100644 --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -1497,7 +1497,7 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a // CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, // CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) { -// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) // CHECK: return // CHECK: } func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) { @@ -1520,7 +1520,7 @@ func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %ar // CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>, // CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) { -// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) // CHECK: return // CHECK: } func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) { @@ -1543,7 +1543,7 @@ func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref< // CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>, // CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) { -// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) // CHECK: return // CHECK: } func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) { @@ -1566,7 +1566,7 @@ func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: // CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>, // CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) { -// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) // CHECK: return // CHECK: } @@ -1622,7 +1622,7 @@ func.func @batch_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: me // CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>, // CHECK-SAME: %[[VAL_1:.*]]: memref<2x7x5xf32>, // CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) { -// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] +// CHECK: linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) // CHECK: return // CHECK: } func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) { diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 94f8ea4faf4a8..307a88709ad52 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -466,3 +466,103 @@ def matmul_as_contract_op( ) print(module) + + +# CHECK-LABEL: TEST: testBatchMatmulOp +@run +def testBatchMatmulOp(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + a_shape = (2, 4, 8) + b_shape = (2, 8, 12) + b_transposed_shape = (2, 12, 8) + c_shape = (2, 4, 12) + + dimBatch = ir.AffineDimExpr.get(0) + dimM = ir.AffineDimExpr.get(1) + dimN = ir.AffineDimExpr.get(2) + dimK = ir.AffineDimExpr.get(3) + + # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> + # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> + # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + + a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK]) + b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK]) + c_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimN]) + + # CHECK: func.func @batch_matmul_op( + @func.FuncOp.from_py_func( + # CHECK-SAME: %[[A:.*]]: tensor<2x4x8xf32>, + RankedTensorType.get(a_shape, f32), + # CHECK-SAME: %[[Amem:.*]]: memref<2x4x8xf32>, + MemRefType.get(a_shape, f32), + # CHECK-SAME: %[[B:.*]]: tensor<2x8x12xf32>, + RankedTensorType.get(b_shape, f32), + # CHECK-SAME: %[[Bmem:.*]]: memref<2x8x12xf32>, + MemRefType.get(b_shape, f32), + # CHECK-SAME: %[[BTrans:.*]]: tensor<2x12x8xf32>, + RankedTensorType.get(b_transposed_shape, f32), + # CHECK-SAME: %[[BTransmem:.*]]: memref<2x12x8xf32>, + MemRefType.get(b_transposed_shape, f32), + # CHECK-SAME: %[[C:.*]]: tensor<2x4x12xf32>, + RankedTensorType.get(c_shape, f32), + # CHECK-SAME: %[[Cmem:.*]]: memref<2x4x12xf32>) + MemRefType.get(c_shape, f32), + ) + def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem): + # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>) + res = linalg.BatchMatmulOp( + result_tensors=(C.type,), + inputs=(A, B), + outputs=(C,), + ) + linalg.fill_builtin_region(res.operation) + # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>) + res = linalg.batch_matmul(A, B, outs=(C,)) + + # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>) + res = linalg.BatchMatmulOp( + result_tensors=(C.type,), + inputs=(A, Btransposed), + outputs=(C,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + linalg.fill_builtin_region(res.operation) + # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>) + res = linalg.batch_matmul( + A, + Btransposed, + outs=(C,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + + # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>) + res = linalg.BatchMatmulOp( + result_tensors=[], + inputs=(Amem, Bmem), + outputs=(Cmem,), + ) + linalg.fill_builtin_region(res.operation) + # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>) + linalg.batch_matmul(Amem, Bmem, outs=(Cmem,)) + + # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>) + res = linalg.BatchMatmulOp( + result_tensors=[], + inputs=(Amem, Btransposedmem), + outputs=(Cmem,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + linalg.fill_builtin_region(res.operation) + # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>) + linalg.batch_matmul( + Amem, + Btransposedmem, + outs=(Cmem,), + indexing_maps=[a_map, b_transposed_map, c_map], + ) + + print(module)