@@ -690,34 +690,32 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [

    Example Transpose:
    ```mlir
-    linalg.matmul indexing_maps = [
-                     affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
-                     affine_map<(d0, d1, d2) -> (d2, d1)>,
-                     affine_map<(d0, d1, d2) -> (d0, d1)>
-                     ]
-                     ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
-                     outs(%arg2: memref<3x7xf32>)
+    linalg.matmul
+      indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+      outs(%arg2: memref<3x7xf32>)
    ```

    Example Broadcast:
-    ```mlir
-    linalg.matmul indexing_maps = [
-                     affine_map<(d0, d1, d2) -> (d2)>, // broadcast
-                     affine_map<(d0, d1, d2) -> (d2, d1)>,
-                     affine_map<(d0, d1, d2) -> (d0, d1)>
-                     ]
-                     ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
-                     outs(%arg2: memref<3x7xf32>)
+    ```mlir
+    linalg.matmul
+      indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+      outs(%arg2: memref<3x7xf32>)
    ```

    Example Broadcast and transpose:
    ```mlir
-    linalg.matmul indexing_maps = [
-                     affine_map<(d0, d1, d2 ) -> (d2, d0 )>, // transpose
-                     affine_map<(d0, d1, d2 ) -> (d2 )>, // broadcast
-                     affine_map<(d0, d1, d2 ) -> (d0, d1)>
-                     ]
-                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+    linalg.matmul
+      indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
+                       affine_map<(m, n, k) -> (k)>, // broadcast
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>)
+      outs(%arg2: memref<3x7xf32>)
    ```
  }];

@@ -775,7 +773,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
    static void regionBuilder(ImplicitLocOpBuilder &b,
                              Block &block, ArrayRef<NamedAttribute> attrs);

-    /// Returns a list of AffineMap with the typical matmul indexing charactristic.
+    /// Returns a list of AffineMap with the default matmul indexing characteristic.
    static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);

    /// Returns true if the given broadcast map \p bcastMap is valid for this op.
@@ -954,35 +952,32 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz

    Example Transpose:
    ```mlir
-    linalg.batch_matmul indexing_maps = [
-                     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
-                     affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
-                     outs(%arg2: memref<2x3x7xf32>)
+    linalg.batch_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (batch, m, n)>]
+      ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
+      outs(%arg2: memref<2x3x7xf32>)
    ```

    Example Broadcast:
    ```mlir
-    linalg.batch_matmul indexing_maps = [
-                     affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
-                     affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
-                     outs(%arg2: memref<2x3x7xf32>)
+    linalg.batch_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (batch, m, n)>]
+      ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+      outs(%arg2: memref<2x3x7xf32>)
    ```

    Example Broadcast and Transpose:
    ```mlir
-    linalg.batch_matmul indexing_maps = [
-                     affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
-                     affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
-                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
-                     outs(%arg2: memref<2x3x7xf32>)
+    linalg.batch_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast
+                       affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
+                       affine_map<(batch, m, n, k) -> (batch, m, n)>]
+      ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+      outs(%arg2: memref<2x3x7xf32>)
    ```
  }];

@@ -1065,6 +1060,134 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
}


+//===----------------------------------------------------------------------===//
+// Op definition for BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
+    AttrSizedOperandSegments,
+    LinalgContractionOpInterface]> {
+
+  let summary = [{Performs a batch-reduce matrix multiplication of two inputs.
+                  The partial multiplication results are reduced into a 2D output.}];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be applied by specifying the explicit
+    attribute 'indexing_maps' as shown below. This is a list attribute, so it
+    must include maps for all arguments if specified.
+
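+    Without an explicit 'indexing_maps' attribute the default maps are used and
+    the op accumulates A[batch, m, k] * B[batch, k, n] into O[m, n] over both
+    batch and k. A minimal sketch of that default form (shapes chosen only for
+    illustration):
+    ```mlir
+    linalg.batch_reduce_matmul
+      ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x5x7xf32>)
+      outs(%arg2: memref<3x7xf32>)
+    ```
+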
1081
+    Example Transpose:
+    ```mlir
+    linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
+      outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```mlir
+    linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+      outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast and Transpose:
+    ```mlir
+    linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast
+                       affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+      outs(%arg2: memref<3x7xf32>)
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    DefaultValuedOptionalAttr<
+      AffineMapArrayAttr,
+      "BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+    >:$indexing_maps,
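+    // Selects how the inner-multiply operands are numerically cast to the
+    // accumulator/output element type (signed casting by default).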
+    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
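+  // The three builders below create the op (1) deriving result types from
+  // 'outputs', (2) with explicit result types, and (3) with explicit result
+  // types plus an explicit 'cast' attribute. All of them install the default
+  // indexing maps and the op's region builder.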
+  let builders = [
+    OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+                                 attributes, BatchReduceMatmulOp::getRegionBuilder(),
+                                 BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
+                                 inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
+                                 BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs,
+           "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+                                 attributes, BatchReduceMatmulOp::getRegionBuilder(),
+                                 BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>
+
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// Implements the block region builder.
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    /// Returns a list of AffineMap with the default batch_reduce_matmul indexing characteristic.
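+    /// By default these express the standard batch-reduce contraction:
+    /// lhs -> (batch, m, k), rhs -> (batch, k, n), output -> (m, n).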
+    static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+    /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+    bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+
+    // Generic methods.
+    static unsigned getNumRegionArgs();
+    std::string getLibraryCallName();
+    bool hasDynamicIndexingMaps() { return true; };
+    /// Returns true if the user-defined indexing maps differ from the default maps.
+    bool hasUserDefinedMaps();
+  }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as declarative configurations of generic ops.
//===----------------------------------------------------------------------===//