-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… #130944
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Md Asghar Ahmad Shahid (shahidact) Changes…_reduce_matmul. This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. This is the last one in continuation of other two variant of matmul ops. The broadcast and transpose semantic are as follows: Broadcast and Transpose semantics can be appiled by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so must include maps for all arguments if specified.
RFCs and related PR: Patch is 45.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130944.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b44af2defc3e4..6344861c53ac5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1717,76 +1717,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: BZp
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_reduce_matmul
- cpp_class_name: BatchReduceMatmulOp
- doc: |-
- Performs a batch-reduce matrix multiplication of two 3D inputs.
- The partial multiplication results are reduced into a 2D output.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
- iterator_types:
- - reduction
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
cpp_class_name: MatvecOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..5191a658bbf26 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -1054,6 +1054,137 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
}
+//===----------------------------------------------------------------------===//
+// Op definition for BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
+ AttrSizedOperandSegments,
+ LinalgContractionOpInterface]> {
+
+ let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
+The partial multiplication results are reduced into a 2D output.}];
+ let description = [{
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+
+ Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+ 'indexing_maps' as shown below. This is a list attribute, so must include maps for all
+ arguments if specified.
+
+ Example Transpose:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast and Transpose:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ DefaultValuedOptionalAttr<
+ AffineMapArrayAttr,
+ "BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+ >:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("cast", cast);
+ buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+ attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+ /// Implements the block region builder.
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+ /// Returns a list of AffineMap with the typical batch_reducematmul indexing charactristic.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+ /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ std::string getLibraryCallName();
+ bool hasDynamicIndexingMaps() { return true; };
+ /// Returns true if the user defined indexing maps are not equal to default maps.
+ bool hasUserDefinedMaps();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..d46fbf988d762 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -218,6 +218,23 @@ static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
}
+static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder,
+ ArrayRef<AffineMap> indexingMaps) {
+ // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
+ SmallVector<Attribute, 4> indexingMapsAttrVal;
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
+ state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+ return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+ attributes, regionBuilder);
+}
+
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3464,19 +3481,24 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return success();
}
-// Check general validity of input indexing map.
-static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+// Check general validity of input indexing map of
+// BatchMatmulOp/BatchReduceMatmulOp.
+template <typename OpTy>
+static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
AffineMap opIndexingMap,
AffineMap defaultIndexingMap, bool isLHS) {
+ assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+ isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+ "Expected BatchMatmulOp or BatchReduceMatmulOp");
// Check the result dims are valid.
if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Unexpected result dim expression (outside the set of default "
"result dims).";
// Check for valid number of result dims of input maps.
if (opIndexingMap.getNumResults() > 3)
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "no. of result dim expressions exceeds 3.";
auto hasValidBatchDim = [](AffineMap map) {
@@ -3486,60 +3508,83 @@ static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
// Check if the requested broadcast is valid.
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
- if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
- return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+ if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+ return batchVariantMatmulOp->emitOpError()
+ << "Invalid broadcast requested.";
} else if (!hasValidBatchDim(opIndexingMap)) {
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Invalid batch dimension expression.";
}
return success();
}
/// This function checks if the given AffineMap for the output of a
-/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
-/// dimensions are valid.
-static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
+/// dimensions and if the output map result dimensions are valid.
+template <typename OpTy>
+static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
AffineMap opIndexingMap) {
- if (opIndexingMap.getNumResults() != 3)
- return batchMatmulOp->emitOpError()
+ assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+ isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+ "Expected BatchMatmulOp or BatchReduceMatmulOp");
+ if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
+ opIndexingMap.getNumResults() != 3) {
+
+ return batchVariantMatmulOp->emitOpError()
<< "expects 3 dims, but got (" << opIndexingMap.getNumResults()
<< ").";
+ }
+ if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
+ opIndexingMap.getNumResults() != 2) {
+ return batchVariantMatmulOp->emitOpError()
+ << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
+ << ").";
+ }
- auto areValidOutputResultDim = [](AffineMap outputMap) {
- return outputMap.getResult(0).isFunctionOfDim(0) &&
- outputMap.getResult(1).isFunctionOfDim(1) &&
- outputMap.getResult(2).isFunctionOfDim(2);
+ auto areValidOutputResultDim = [&](AffineMap outputMap) {
+ return isa<BatchMatmulOp>(batchVariantMatmulOp)
+ ? outputMap.getResult(0).isFunctionOfDim(0) &&
+ outputMap.getResult(1).isFunctionOfDim(1) &&
+ outputMap.getResult(2).isFunctionOfDim(2)
+ : outputMap.getResult(0).isFunctionOfDim(1) &&
+ outputMap.getResult(1).isFunctionOfDim(2);
};
- if (!areValidOutputResultDim(opIndexingMap))
- return batchMatmulOp->emitOpError()
+ if (!areValidOutputResultDim(opIndexingMap)) {
+ return batchVariantMatmulOp->emitOpError()
<< "Invalid output map result dimension.";
+ }
return success();
}
/// Verifies the broadcast and transpose semantic specified by the explicit
-/// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
+/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
+/// specified by opIndex.
+template <typename OpTy>
static LogicalResult
-verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
- unsigned opIndex) {
+verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
+ unsigned opIndex) {
SmallVector<AffineMap, 3> opIndexingMaps =
- batchMatmulOp.getIndexingMapsArray();
+ batchVariantMatmulOp.getIndexingMapsArray();
SmallVector<AffineMap, 3> defaultIndexingMaps =
- batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+ batchVariantMatmulOp.getDefaultIndexingMaps(
+ batchVariantMatmulOp->getContext());
if (opIndexingMaps.size() != 3)
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Indexing_map attribute must have 3 affine maps.";
auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
- if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
+ if (opIndex == 2 &&
+ failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
return failure();
- if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
- opIndex == 0)))
+ if (opIndex != 2 &&
+ failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
+ defaultIndexingMap, opIndex == 0)))
return failure();
return success();
@@ -4035,7 +4080,7 @@ LogicalResult BatchMatmulOp::verify() {
return success();
for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
- if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+ if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
return failure();
}
return success();
@@ -5340,6 +5385,167 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
}
};
+//===----------------------------------------------------------------------===//
+// BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{
+ utils::IteratorType::reduction, utils::IteratorType::parallel,
+ utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+SmallVector<AffineMap>
+BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+ AffineExpr d0, d1, d2, d3;
+ SmallVector<AffineMap> indexingMaps;
+ bindDims(context, d0, d1, d2, d3);
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
+ return indexingMaps;
+}
+
+unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchReduceMatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantic. Returns true if
+/// the user defined indexing maps are not equal to default map.
+bool BatchReduceMatmulOp::hasUserDefinedMaps() {
+ SmallVector<AffineMap, 3> defaultMaps =
+ getDefaultIndexingMaps(this->getContext());
+ SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+ return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map bcastMap is valid for this op.
+bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
+ bool isLHS) {
+ assert(bcastMap.getNumResults() < 3 &&
+ "Expected less than 3 result dim expr.");
+ bool isValid = false;
+ enum Indices { batchPos, mPos, nPos, kPos };
+ if (bcastMap.getNumResults() == 1) {
+ AffineExpr exp = bcastMap.getResult(0);
+ isValid = exp.isFunctionOfDim(kPos);
+ } else if (bcastMap.getNumResults() == 2) {
+ AffineExpr exp0 = bcastMap.getResult(0);
+ AffineExpr exp1 = bcastMap.getResult(1);
+ isValid = isLHS
+ ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+ : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+ }
+ return isValid;
+}
+
+void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ auto toType = block.getArgument(2).getType();
+ Value castValA =
+ helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
+ Value castValB =
+ helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+ Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+ Value addVal =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+ yields.push_back(addVal);
+ helper.yieldOutputs(yields);
+}
+
+ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+ if (parser.parseLSquare())
+ ...
[truncated]
|
@llvm/pr-subscribers-mlir-linalg Author: Md Asghar Ahmad Shahid (shahidact) Changes…_reduce_matmul. This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. This is the last one in continuation of other two variant of matmul ops. The broadcast and transpose semantic are as follows: Broadcast and Transpose semantics can be appiled by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so must include maps for all arguments if specified.
RFCs and related PR: Patch is 45.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130944.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b44af2defc3e4..6344861c53ac5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1717,76 +1717,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: BZp
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_reduce_matmul
- cpp_class_name: BatchReduceMatmulOp
- doc: |-
- Performs a batch-reduce matrix multiplication of two 3D inputs.
- The partial multiplication results are reduced into a 2D output.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
- iterator_types:
- - reduction
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
cpp_class_name: MatvecOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..5191a658bbf26 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -1054,6 +1054,137 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
}
+//===----------------------------------------------------------------------===//
+// Op definition for BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
+ AttrSizedOperandSegments,
+ LinalgContractionOpInterface]> {
+
+ let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
+The partial multiplication results are reduced into a 2D output.}];
+ let description = [{
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+
+ Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+ 'indexing_maps' as shown below. This is a list attribute, so must include maps for all
+ arguments if specified.
+
+ Example Transpose:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast and Transpose:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ DefaultValuedOptionalAttr<
+ AffineMapArrayAttr,
+ "BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+ >:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("cast", cast);
+ buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+ attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+ /// Implements the block region builder.
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+ /// Returns a list of AffineMap with the typical batch_reducematmul indexing charactristic.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+ /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ std::string getLibraryCallName();
+ bool hasDynamicIndexingMaps() { return true; };
+ /// Returns true if the user defined indexing maps are not equal to default maps.
+ bool hasUserDefinedMaps();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..d46fbf988d762 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -218,6 +218,23 @@ static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
}
+static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder,
+ ArrayRef<AffineMap> indexingMaps) {
+ // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
+ SmallVector<Attribute, 4> indexingMapsAttrVal;
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
+ state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+ return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+ attributes, regionBuilder);
+}
+
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3464,19 +3481,24 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return success();
}
-// Check general validity of input indexing map.
-static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+// Check general validity of input indexing map of
+// BatchMatmulOp/BatchReduceMatmulOp.
+template <typename OpTy>
+static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
AffineMap opIndexingMap,
AffineMap defaultIndexingMap, bool isLHS) {
+ assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+ isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+ "Expected BatchMatmulOp or BatchReduceMatmulOp");
// Check the result dims are valid.
if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Unexpected result dim expression (outside the set of default "
"result dims).";
// Check for valid number of result dims of input maps.
if (opIndexingMap.getNumResults() > 3)
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "no. of result dim expressions exceeds 3.";
auto hasValidBatchDim = [](AffineMap map) {
@@ -3486,60 +3508,83 @@ static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
// Check if the requested broadcast is valid.
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
- if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
- return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+ if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+ return batchVariantMatmulOp->emitOpError()
+ << "Invalid broadcast requested.";
} else if (!hasValidBatchDim(opIndexingMap)) {
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Invalid batch dimension expression.";
}
return success();
}
/// This function checks if the given AffineMap for the output of a
-/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
-/// dimensions are valid.
-static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
+/// dimensions and if the output map result dimensions are valid.
+template <typename OpTy>
+static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
AffineMap opIndexingMap) {
- if (opIndexingMap.getNumResults() != 3)
- return batchMatmulOp->emitOpError()
+ assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+ isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+ "Expected BatchMatmulOp or BatchReduceMatmulOp");
+ if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
+ opIndexingMap.getNumResults() != 3) {
+
+ return batchVariantMatmulOp->emitOpError()
<< "expects 3 dims, but got (" << opIndexingMap.getNumResults()
<< ").";
+ }
+ if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
+ opIndexingMap.getNumResults() != 2) {
+ return batchVariantMatmulOp->emitOpError()
+ << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
+ << ").";
+ }
- auto areValidOutputResultDim = [](AffineMap outputMap) {
- return outputMap.getResult(0).isFunctionOfDim(0) &&
- outputMap.getResult(1).isFunctionOfDim(1) &&
- outputMap.getResult(2).isFunctionOfDim(2);
+ auto areValidOutputResultDim = [&](AffineMap outputMap) {
+ return isa<BatchMatmulOp>(batchVariantMatmulOp)
+ ? outputMap.getResult(0).isFunctionOfDim(0) &&
+ outputMap.getResult(1).isFunctionOfDim(1) &&
+ outputMap.getResult(2).isFunctionOfDim(2)
+ : outputMap.getResult(0).isFunctionOfDim(1) &&
+ outputMap.getResult(1).isFunctionOfDim(2);
};
- if (!areValidOutputResultDim(opIndexingMap))
- return batchMatmulOp->emitOpError()
+ if (!areValidOutputResultDim(opIndexingMap)) {
+ return batchVariantMatmulOp->emitOpError()
<< "Invalid output map result dimension.";
+ }
return success();
}
/// Verifies the broadcast and transpose semantic specified by the explicit
-/// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
+/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
+/// specified by opIndex.
+template <typename OpTy>
static LogicalResult
-verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
- unsigned opIndex) {
+verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
+ unsigned opIndex) {
SmallVector<AffineMap, 3> opIndexingMaps =
- batchMatmulOp.getIndexingMapsArray();
+ batchVariantMatmulOp.getIndexingMapsArray();
SmallVector<AffineMap, 3> defaultIndexingMaps =
- batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+ batchVariantMatmulOp.getDefaultIndexingMaps(
+ batchVariantMatmulOp->getContext());
if (opIndexingMaps.size() != 3)
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Indexing_map attribute must have 3 affine maps.";
auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
- if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
+ if (opIndex == 2 &&
+ failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
return failure();
- if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
- opIndex == 0)))
+ if (opIndex != 2 &&
+ failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
+ defaultIndexingMap, opIndex == 0)))
return failure();
return success();
@@ -4035,7 +4080,7 @@ LogicalResult BatchMatmulOp::verify() {
return success();
for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
- if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+ if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
return failure();
}
return success();
@@ -5340,6 +5385,167 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
}
};
+//===----------------------------------------------------------------------===//
+// BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{
+ utils::IteratorType::reduction, utils::IteratorType::parallel,
+ utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+SmallVector<AffineMap>
+BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+ AffineExpr d0, d1, d2, d3;
+ SmallVector<AffineMap> indexingMaps;
+ bindDims(context, d0, d1, d2, d3);
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
+ return indexingMaps;
+}
+
+unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchReduceMatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantic. Returns true if
+/// the user defined indexing maps are not equal to default map.
+bool BatchReduceMatmulOp::hasUserDefinedMaps() {
+ SmallVector<AffineMap, 3> defaultMaps =
+ getDefaultIndexingMaps(this->getContext());
+ SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+ return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map bcastMap is valid for this op.
+bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
+ bool isLHS) {
+ assert(bcastMap.getNumResults() < 3 &&
+ "Expected less than 3 result dim expr.");
+ bool isValid = false;
+ enum Indices { batchPos, mPos, nPos, kPos };
+ if (bcastMap.getNumResults() == 1) {
+ AffineExpr exp = bcastMap.getResult(0);
+ isValid = exp.isFunctionOfDim(kPos);
+ } else if (bcastMap.getNumResults() == 2) {
+ AffineExpr exp0 = bcastMap.getResult(0);
+ AffineExpr exp1 = bcastMap.getResult(1);
+ isValid = isLHS
+ ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+ : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+ }
+ return isValid;
+}
+
+void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ auto toType = block.getArgument(2).getType();
+ Value castValA =
+ helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
+ Value castValB =
+ helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+ Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+ Value addVal =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+ yields.push_back(addVal);
+ helper.yieldOutputs(yields);
+}
+
+ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+ if (parser.parseLSquare())
+ ...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
42743b6
to
ab40dbe
Compare
Pinging for your kind attention :) @MaheshRavishankar @nicolasvasilache @banach-space and others |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution and apologies for the delay reviewing!
From what I can tell, the semantics match what we have already introduced for other Ops (e.g. linalg.matmul
+ linalg.contract
).
Since it's a Friday, I'm starting with the easier part - formatting 😅 I will go over C++ later.
Thanks for the ping. I skimmed through the change and I think it looks fine to me. I am not rely in a position to do a deep review soon-ish, but it seems consistent with the |
Thanks for your PR, in its current form I believe it will break the python side of things though. |
Thanks for the review, updated accordingly. |
Ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay — returning to this after a short break.
Most of my comments are related to test hygiene; they’re minor, but still worth addressing.
While reviewing invalid.mlir, I got a bit confused by the "invalid broadcast" case — hopefully you can clarify that (I probably just need to refresh my memory). On a related note: since the verification logic is shared, shouldn’t the invalid tests for linalg.batch_matmul
and linalg.reduce_batch_matmul
be nearly identical?
SmallVector<StringRef, 3> elidedAttrs = { | ||
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"}; | ||
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), | ||
elidedAttrs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, I don't see this appearing in tests. What is it? :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By default, printer exclude these attributes from printing. The explicit indexing_maps
are printed if it is not same as default indexing_maps
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me rephrase: what are linalg.memoized_indexing_maps
and when will these be printed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
memoized_indexing_maps
IIUC this is to cache the indexing_map array attribute. It gets printed on diagnostics. @nicolasvasilache could you pls enlighten us.
Hi @banach-space IMO, I addressed all the comment. |
Argh, sorry, I had un-submitted questions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM. I would like my two questions to be addressed before merging, but they are non-blocking, so I'm approving as is.
Please allow at least a few more days before landing - it would be good to hear from other reviewers as well.
Thanks for the effort - great to see all these improvements to Linalg!
…_reduce_matmul. This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. This is the last one in continuation of other two variant of matmul ops. The broadcast and transpose semantic are as follows: Broadcast and Transpose semantics can be appiled by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so must include maps for all arguments if specified. Example Transpose: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>) ``` Example Broadcast: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>) ``` Example Broadcast and Transpose: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>) ``` RFCs and related PR: https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149 https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863 https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586 llvm#115319 llvm#122275
-Added consistent variable and function naming in test cases. -Improved ops indexing_maps description.
for broadcast map from the existing one.
Hi @javedabsar1, @banach-space this PR is ready IMO, pls have a look. |
Sorry, I was away last week. I am OK with this being merged as is, we can iterate in-tree. Please wait for @javedabsar1 to approve before merging. Thanks for working on this! 🙏🏻 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks .
congrats @shahidact on this 2 month long journey 😄 |
Thanks, Enjoyed it :) |
…_reduce_matmul.
This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. This is the last one in continuation of other two variant of matmul ops.
The broadcast and transpose semantic are as follows:
Broadcast and Transpose semantics can be appiled by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so must include maps for all arguments if specified.
RFCs and related PR:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
#115319
#122275