[MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… #130944

Merged
merged 7 commits into llvm:main on May 12, 2025

Conversation

shahidact
Contributor

@shahidact shahidact commented Mar 12, 2025

…_reduce_matmul.

This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. It is the last of the three matmul variants, following the same extension already made to the other two matmul ops.

Broadcast and transpose semantics can be applied by specifying the explicit 'indexing_maps' attribute, as shown below. This is a list attribute, so if specified it must include maps for all arguments.

Example Transpose:
```
linalg.batch_reduce_matmul indexing_maps = [
   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
   affine_map<(d0, d1, d2, d3) -> (d1, d2)>
   ]
      ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
      outs(%arg2: memref<3x7xf32>)
```

Example Broadcast:
```
linalg.batch_reduce_matmul indexing_maps = [
   affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
   affine_map<(d0, d1, d2, d3) -> (d1, d2)>
   ]
      ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
      outs(%arg2: memref<3x7xf32>)
```

Example Broadcast and Transpose:
```
linalg.batch_reduce_matmul indexing_maps = [
   affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
   affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
   affine_map<(d0, d1, d2, d3) -> (d1, d2)>
   ]
      ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
      outs(%arg2: memref<3x7xf32>)
```
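For comparison, a minimal sketch of the plain form without an explicit 'indexing_maps' attribute; it is equivalent to supplying the default maps, and the operand shapes here are chosen purely for illustration:

```
// Default semantics: A is (batch, M, K), B is (batch, K, N); the batch
// dimension is reduced into the 2D accumulator C of shape (M, N).
linalg.batch_reduce_matmul
    ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x5x7xf32>)
    outs(%arg2 : memref<3x7xf32>)
```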

RFCs and related PRs:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
#115319
#122275

@llvmbot
Member

llvmbot commented Mar 12, 2025

@llvm/pr-subscribers-mlir

Author: Md Asghar Ahmad Shahid (shahidact)



Patch is 45.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130944.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-70)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+131)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+234-28)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+30)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+202)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+165)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b44af2defc3e4..6344861c53ac5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1717,76 +1717,6 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: BZp
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_reduce_matmul
-  cpp_class_name: BatchReduceMatmulOp
-  doc: |-
-    Performs a batch-reduce matrix multiplication of two 3D inputs.
-    The partial multiplication results are reduced into a 2D output.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
-  iterator_types:
-  - reduction
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matvec
   cpp_class_name: MatvecOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..5191a658bbf26 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -1054,6 +1054,137 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
 }
 
 
+//===----------------------------------------------------------------------===//
+// Op definition for BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
+                          AttrSizedOperandSegments,
+                          LinalgContractionOpInterface]> {
+    
+  let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
+The partial multiplication results are reduced into a 2D output.}];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+    'indexing_maps' as shown below. This is a list attribute, so must include maps for all
+    arguments if specified.
+
+    Example Transpose:
+    ```
+    linalg.batch_reduce_matmul indexing_maps = [
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                   affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                   ]
+                   ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+                   outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```
+    linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+                     outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast and Transpose:
+    ```
+    linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+                     outs(%arg2: memref<3x7xf32>)
+    ```
+    }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<
+        AffineMapArrayAttr,
+        "BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+      >:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, BatchReduceMatmulOp::getRegionBuilder(),
+          BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
+          BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs,
+       "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, BatchReduceMatmulOp::getRegionBuilder(),
+          BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>
+      
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// Implements the block region builder.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      /// Returns a list of AffineMap with the typical batch_reducematmul indexing charactristic.
+      static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+      bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      bool hasDynamicIndexingMaps() { return true; };
+      /// Returns true if the user defined indexing maps are not equal to default maps.
+      bool hasUserDefinedMaps();
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..d46fbf988d762 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -218,6 +218,23 @@ static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                            attributes, regionBuilder);
 }
 
+static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
+                                     std::optional<TypeRange> resultTensorTypes,
+                                     ValueRange inputs, ValueRange outputs,
+                                     ArrayRef<NamedAttribute> attributes,
+                                     RegionBuilderFn regionBuilder,
+                                     ArrayRef<AffineMap> indexingMaps) {
+  // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
+  SmallVector<Attribute, 4> indexingMapsAttrVal;
+  indexingMapsAttrVal =
+      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+        return AffineMapAttr::get(map);
+      });
+  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+                           attributes, regionBuilder);
+}
+
 /// Common parsing used for both named structured ops created by ods-gen and by
 /// manually defined C++ ops. Does not handle regions.
 static ParseResult
@@ -3464,19 +3481,24 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
   return success();
 }
 
-// Check general validity of input indexing map.
-static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+// Check general validity of input indexing map of
+// BatchMatmulOp/BatchReduceMatmulOp.
+template <typename OpTy>
+static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
                                      AffineMap opIndexingMap,
                                      AffineMap defaultIndexingMap, bool isLHS) {
+  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+         "Expected BatchMatmulOp or BatchReduceMatmulOp");
   // Check the result dims are valid.
   if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
-    return batchMatmulOp->emitOpError()
+    return batchVariantMatmulOp->emitOpError()
            << "Unexpected result dim expression (outside the set of default "
               "result dims).";
 
   // Check for valid number of result dims of input maps.
   if (opIndexingMap.getNumResults() > 3)
-    return batchMatmulOp->emitOpError()
+    return batchVariantMatmulOp->emitOpError()
            << "no. of result dim expressions exceeds 3.";
 
   auto hasValidBatchDim = [](AffineMap map) {
@@ -3486,60 +3508,83 @@ static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
 
   // Check if the requested broadcast is valid.
   if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
-    if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
-      return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+    if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+      return batchVariantMatmulOp->emitOpError()
+             << "Invalid broadcast requested.";
   } else if (!hasValidBatchDim(opIndexingMap)) {
-    return batchMatmulOp->emitOpError()
+    return batchVariantMatmulOp->emitOpError()
            << "Invalid batch dimension expression.";
   }
   return success();
 }
 
 /// This function checks if the given AffineMap for the output of a
-/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
-/// dimensions are valid.
-static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
+/// dimensions and if the output map result dimensions are valid.
+template <typename OpTy>
+static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
                                      AffineMap opIndexingMap) {
-  if (opIndexingMap.getNumResults() != 3)
-    return batchMatmulOp->emitOpError()
+  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+         "Expected BatchMatmulOp or BatchReduceMatmulOp");
+  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
+      opIndexingMap.getNumResults() != 3) {
+
+    return batchVariantMatmulOp->emitOpError()
            << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
            << ").";
+  }
+  if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
+      opIndexingMap.getNumResults() != 2) {
+    return batchVariantMatmulOp->emitOpError()
+           << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
+           << ").";
+  }
 
-  auto areValidOutputResultDim = [](AffineMap outputMap) {
-    return outputMap.getResult(0).isFunctionOfDim(0) &&
-           outputMap.getResult(1).isFunctionOfDim(1) &&
-           outputMap.getResult(2).isFunctionOfDim(2);
+  auto areValidOutputResultDim = [&](AffineMap outputMap) {
+    return isa<BatchMatmulOp>(batchVariantMatmulOp)
+               ? outputMap.getResult(0).isFunctionOfDim(0) &&
+                     outputMap.getResult(1).isFunctionOfDim(1) &&
+                     outputMap.getResult(2).isFunctionOfDim(2)
+               : outputMap.getResult(0).isFunctionOfDim(1) &&
+                     outputMap.getResult(1).isFunctionOfDim(2);
   };
 
-  if (!areValidOutputResultDim(opIndexingMap))
-    return batchMatmulOp->emitOpError()
+  if (!areValidOutputResultDim(opIndexingMap)) {
+    return batchVariantMatmulOp->emitOpError()
            << "Invalid output map result dimension.";
+  }
 
   return success();
 }
 
 /// Verifies the broadcast and transpose semantic specified by the explicit
-/// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
+/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
+/// specified by opIndex.
+template <typename OpTy>
 static LogicalResult
-verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
-                                  unsigned opIndex) {
+verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
+                                         unsigned opIndex) {
   SmallVector<AffineMap, 3> opIndexingMaps =
-      batchMatmulOp.getIndexingMapsArray();
+      batchVariantMatmulOp.getIndexingMapsArray();
   SmallVector<AffineMap, 3> defaultIndexingMaps =
-      batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+      batchVariantMatmulOp.getDefaultIndexingMaps(
+          batchVariantMatmulOp->getContext());
 
   if (opIndexingMaps.size() != 3)
-    return batchMatmulOp->emitOpError()
+    return batchVariantMatmulOp->emitOpError()
            << "Indexing_map attribute must have 3 affine maps.";
 
   auto opIndexingMap = opIndexingMaps[opIndex];
   auto defaultIndexingMap = defaultIndexingMaps[opIndex];
 
-  if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
+  if (opIndex == 2 &&
+      failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
     return failure();
 
-  if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
-                             opIndex == 0)))
+  if (opIndex != 2 &&
+      failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
+                             defaultIndexingMap, opIndex == 0)))
     return failure();
 
   return success();
@@ -4035,7 +4080,7 @@ LogicalResult BatchMatmulOp::verify() {
     return success();
 
   for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
-    if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
       return failure();
   }
   return success();
@@ -5340,6 +5385,167 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{
+      utils::IteratorType::reduction, utils::IteratorType::parallel,
+      utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+SmallVector<AffineMap>
+BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+  AffineExpr d0, d1, d2, d3;
+  SmallVector<AffineMap> indexingMaps;
+  bindDims(context, d0, d1, d2, d3);
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+  indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
+  return indexingMaps;
+}
+
+unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchReduceMatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantic. Returns true if
+/// the user defined indexing maps are not equal to default map.
+bool BatchReduceMatmulOp::hasUserDefinedMaps() {
+  SmallVector<AffineMap, 3> defaultMaps =
+      getDefaultIndexingMaps(this->getContext());
+  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+  return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map bcastMap is valid for this op.
+bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
+                                                    bool isLHS) {
+  assert(bcastMap.getNumResults() < 3 &&
+         "Expected less than 3 result dim expr.");
+  bool isValid = false;
+  enum Indices { batchPos, mPos, nPos, kPos };
+  if (bcastMap.getNumResults() == 1) {
+    AffineExpr exp = bcastMap.getResult(0);
+    isValid = exp.isFunctionOfDim(kPos);
+  } else if (bcastMap.getNumResults() == 2) {
+    AffineExpr exp0 = bcastMap.getResult(0);
+    AffineExpr exp1 = bcastMap.getResult(1);
+    isValid = isLHS
+                  ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+                  : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+  }
+  return isValid;
+}
+
+void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                        ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  auto toType = block.getArgument(2).getType();
+  Value castValA =
+      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
+  Value castValB =
+      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+  Value addVal =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+  yields.push_back(addVal);
+  helper.yieldOutputs(yields);
+}
+
+ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+    if (parser.parseLSquare())
+ ...
[truncated]

@llvmbot
Member

llvmbot commented Mar 12, 2025

@llvm/pr-subscribers-mlir-linalg



github-actions bot commented Mar 12, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@shahidact shahidact force-pushed the br-matmul branch 2 times, most recently from 42743b6 to ab40dbe on March 12, 2025 at 14:04
@shahidact
Contributor Author

Pinging for your kind attention :) @MaheshRavishankar @nicolasvasilache @banach-space and others

Contributor

@banach-space banach-space left a comment

Thanks for the contribution and apologies for the delay reviewing!

From what I can tell, the semantics match what we have already introduced for other Ops (e.g. linalg.matmul + linalg.contract).

Since it's a Friday, I'm starting with the easier part - formatting 😅 I will go over C++ later.
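
For reference, the analogous form already supported on linalg.matmul (mentioned above) looks roughly like the sketch below, assuming the syntax introduced in #115319; the shapes are illustrative and A is transposed:

```
// linalg.matmul with an explicit indexing_maps attribute transposing A.
// The default maps are (d0, d2), (d2, d1), (d0, d1) with d0 = M, d1 = N, d2 = K.
linalg.matmul indexing_maps = [
    affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose A
    affine_map<(d0, d1, d2) -> (d2, d1)>,
    affine_map<(d0, d1, d2) -> (d0, d1)>
    ]
    ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
    outs(%arg2 : memref<3x7xf32>)
```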

@MaheshRavishankar
Contributor

Thanks for the ping. I skimmed through the change and I think it looks fine to me. I am not really in a position to do a deep review soon-ish, but it seems consistent with linalg.contract.

@nicolasvasilache
Contributor

Thanks for your PR, in its current form I believe it will break the python side of things though.
Please update and add a test on the python side if you remove the opdsl / yaml source of truth.

@llvmbot llvmbot added the mlir:python (MLIR Python bindings) label on Apr 13, 2025
@shahidact
Contributor Author

Thanks for your PR, in its current form I believe it will break the python side of things though. Please update and add a test on the python side if you remove the opdsl / yaml source of truth.

Thanks for the review, updated accordingly.

@rengolin
Member

Ping

Contributor

@banach-space banach-space left a comment

Sorry for the delay — returning to this after a short break.

Most of my comments are related to test hygiene; they’re minor, but still worth addressing.

While reviewing invalid.mlir, I got a bit confused by the "invalid broadcast" case; hopefully you can clarify that (I probably just need to refresh my memory). On a related note: since the verification logic is shared, shouldn't the invalid tests for linalg.batch_matmul and linalg.batch_reduce_matmul be nearly identical?
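
To make that concrete, a hypothetical invalid case (not taken from invalid.mlir, shapes chosen for illustration): an LHS broadcast map that keeps the batch and K dims but drops M passes the result-subset check yet fails isValidLhsRhsBroadcastMap, so the shared verifier reports "Invalid broadcast requested.":

```
// Hypothetical invalid LHS map: (d0, d3) is a subset of the default
// (d0, d1, d3) but omits the M dim, so isValidLhsRhsBroadcastMap()
// rejects it and the verifier emits "Invalid broadcast requested."
linalg.batch_reduce_matmul indexing_maps = [
    affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
    affine_map<(d0, d1, d2, d3) -> (d1, d2)>
    ]
    ins(%arg0, %arg1 : memref<2x5xf32>, memref<2x5x7xf32>)
    outs(%arg2 : memref<3x7xf32>)
```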

Comment on lines +5511 to +5548
```
  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                           elidedAttrs);
```
Contributor

Hm, I don't see this appearing in tests. What is it? :)

Contributor Author

By default, the printer excludes these attributes from printing. The explicit indexing_maps are printed only when they differ from the default indexing_maps.
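
A hedged round-trip sketch of that behaviour (hypothetical, shapes for illustration only): an op whose maps equal the defaults prints back without the attribute, while a non-default (e.g. transposed) variant keeps it:

```
// Maps equal the defaults, so indexing_maps is elided when printing:
linalg.batch_reduce_matmul
    ins(%a, %b : memref<2x3x5xf32>, memref<2x5x7xf32>)
    outs(%c : memref<3x7xf32>)

// LHS is transposed (non-default maps), so indexing_maps is printed:
linalg.batch_reduce_matmul indexing_maps = [
    affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
    affine_map<(d0, d1, d2, d3) -> (d1, d2)>
    ]
    ins(%a, %b : memref<2x5x3xf32>, memref<2x5x7xf32>)
    outs(%c : memref<3x7xf32>)
```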

Contributor

Let me rephrase: what are linalg.memoized_indexing_maps and when will these be printed?

Contributor Author

Done.

Contributor Author

memoized_indexing_maps

IIUC, this is used to cache the indexing_maps array attribute. It gets printed in diagnostics. @nicolasvasilache could you please enlighten us?

@shahidact
Contributor Author

Hi @banach-space, IMO I have addressed all the comments.

@banach-space
Contributor

Hi @banach-space, IMO I have addressed all the comments.

Argh, sorry, I had un-submitted questions.

Contributor

@banach-space banach-space left a comment

Overall LGTM. I would like my two questions to be addressed before merging, but they are non-blocking, so I'm approving as is.

Please allow at least a few more days before landing - it would be good to hear from other reviewers as well.

Thanks for the effort - great to see all these improvements to Linalg!

…_reduce_matmul.

- Added consistent variable and function naming in test cases.
- Improved the ops' indexing_maps description.
@shahidact
Contributor Author

Hi @javedabsar1, @banach-space this PR is ready IMO, pls have a look.

@banach-space
Contributor

Hi @javedabsar1, @banach-space this PR is ready IMO, pls have a look.

Sorry, I was away last week. I am OK with this being merged as is, we can iterate in-tree. Please wait for @javedabsar1 to approve before merging.

Thanks for working on this! 🙏🏻

Contributor

@javedabsar1 javedabsar1 left a comment

LGTM. Thanks.

@rengolin rengolin merged commit d78ff5f into llvm:main May 12, 2025
11 checks passed
@makslevental
Contributor

congrats @shahidact on this 2 month long journey 😄

@shahidact
Contributor Author

congrats @shahidact on this 2 month long journey 😄

Thanks, Enjoyed it :)
