Skip to content

Commit d78ff5f

Browse files
authored
[MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (llvm#130944)
…_reduce_matmul. This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. This is the last one in continuation of other two variant of matmul ops. The broadcast and transpose semantic are as follows: Broadcast and Transpose semantics can be appiled by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so must include maps for all arguments if specified. Example Transpose: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>) ``` Example Broadcast: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>) ``` Example Broadcast and Transpose: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>) ``` RFCs and related PR: https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149 https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863 https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586 llvm#115319 llvm#122275
1 parent 3ffde4a commit d78ff5f

File tree

8 files changed

+950
-156
lines changed

8 files changed

+950
-156
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,76 +1717,6 @@ structured_op: !LinalgStructuredOpConfig
17171717
- !ScalarExpression
17181718
scalar_arg: BZp
17191719
--- !LinalgOpConfig
1720-
metadata: !LinalgOpMetadata
1721-
name: batch_reduce_matmul
1722-
cpp_class_name: BatchReduceMatmulOp
1723-
doc: |-
1724-
Performs a batch-reduce matrix multiplication of two 3D inputs.
1725-
The partial multiplication results are reduced into a 2D output.
1726-
1727-
Numeric casting is performed on the operands to the inner multiply, promoting
1728-
them to the same data type as the accumulator/output.
1729-
implements:
1730-
- LinalgContractionOpInterface
1731-
structured_op: !LinalgStructuredOpConfig
1732-
args:
1733-
- !LinalgOperandDefConfig
1734-
name: A
1735-
kind: input_tensor
1736-
type_var: T1
1737-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
1738-
- !LinalgOperandDefConfig
1739-
name: B
1740-
kind: input_tensor
1741-
type_var: T2
1742-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
1743-
- !LinalgOperandDefConfig
1744-
name: C
1745-
kind: output_tensor
1746-
type_var: U
1747-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
1748-
indexing_maps: !LinalgIndexingMapsConfig
1749-
static_indexing_maps:
1750-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
1751-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
1752-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
1753-
iterator_types:
1754-
- reduction
1755-
- parallel
1756-
- parallel
1757-
- reduction
1758-
assignments:
1759-
- !ScalarAssign
1760-
arg: C
1761-
value: !ScalarExpression
1762-
scalar_fn:
1763-
kind: binary
1764-
fn_name: add
1765-
operands:
1766-
- !ScalarExpression
1767-
scalar_arg: C
1768-
- !ScalarExpression
1769-
scalar_fn:
1770-
kind: binary
1771-
fn_name: mul
1772-
operands:
1773-
- !ScalarExpression
1774-
scalar_fn:
1775-
kind: type
1776-
fn_name: cast_signed
1777-
type_var: U
1778-
operands:
1779-
- !ScalarExpression
1780-
scalar_arg: A
1781-
- !ScalarExpression
1782-
scalar_fn:
1783-
kind: type
1784-
fn_name: cast_signed
1785-
type_var: U
1786-
operands:
1787-
- !ScalarExpression
1788-
scalar_arg: B
1789-
--- !LinalgOpConfig
17901720
metadata: !LinalgOpMetadata
17911721
name: matvec
17921722
cpp_class_name: MatvecOp

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 166 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -690,34 +690,32 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
690690

691691
Example Transpose:
692692
```mlir
693-
linalg.matmul indexing_maps = [
694-
affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
695-
affine_map<(d0, d1, d2) -> (d2, d1)>,
696-
affine_map<(d0, d1, d2) -> (d0, d1)>
697-
]
698-
ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
699-
outs(%arg2: memref<3x7xf32>)
693+
linalg.matmul
694+
indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
695+
affine_map<(m, n, k) -> (k, n)>,
696+
affine_map<(m, n, k) -> (m, n)>]
697+
ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
698+
outs(%arg2: memref<3x7xf32>)
700699
```
701700

702701
Example Broadcast:
703-
```mlir
704-
linalg.matmul indexing_maps = [
705-
affine_map<(d0, d1, d2) -> (d2)>, // broadcast
706-
affine_map<(d0, d1, d2) -> (d2, d1)>,
707-
affine_map<(d0, d1, d2) -> (d0, d1)>
708-
]
709-
ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
710-
outs(%arg2: memref<3x7xf32>)
702+
```mlir
703+
linalg.matmul
704+
indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
705+
affine_map<(m, n, k) -> (k, n)>,
706+
affine_map<(m, n, k) -> (m, n)>]
707+
ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
708+
outs(%arg2: memref<3x7xf32>)
711709
```
712710

713711
Example Broadcast and transpose:
714712
```mlir
715-
linalg.matmul indexing_maps = [
716-
affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
717-
affine_map<(d0, d1, d2) -> (d2)>, // broadcast
718-
affine_map<(d0, d1, d2) -> (d0, d1)>
719-
]
720-
ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
713+
linalg.matmul
714+
indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
715+
affine_map<(m, n, k) -> (k)>, // broadcast
716+
affine_map<(m, n, k) -> (m, n)>]
717+
ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>)
718+
outs(%arg2: memref<3x7xf32>)
721719
```
722720
}];
723721

@@ -775,7 +773,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
775773
static void regionBuilder(ImplicitLocOpBuilder &b,
776774
Block &block, ArrayRef<NamedAttribute> attrs);
777775

778-
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
776+
/// Returns a list of AffineMap with the default matmul indexing charactristic.
779777
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
780778

781779
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
@@ -954,35 +952,32 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
954952

955953
Example Transpose:
956954
```mlir
957-
linalg.batch_matmul indexing_maps = [
958-
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
959-
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
960-
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
961-
]
962-
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
963-
outs(%arg2: memref<2x3x7xf32>)
955+
linalg.batch_matmul
956+
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
957+
affine_map<(batch, m, n, k) -> (batch, k, n)>,
958+
affine_map<(batch, m, n, k) -> (batch, m, n)>]
959+
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
960+
outs(%arg2: memref<2x3x7xf32>)
964961
```
965962

966963
Example Broadcast:
967964
```mlir
968-
linalg.batch_matmul indexing_maps = [
969-
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
970-
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
971-
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
972-
]
973-
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
974-
outs(%arg2: memref<2x3x7xf32>)
965+
linalg.batch_matmul
966+
indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast
967+
affine_map<(batch, m, n, k) -> (batch, k, n)>,
968+
affine_map<(batch, m, n, k) -> (batch, m, n)>]
969+
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
970+
outs(%arg2: memref<2x3x7xf32>)
975971
```
976972

977973
Example Broadcast and Transpose:
978974
```mlir
979-
linalg.batch_matmul indexing_maps = [
980-
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
981-
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
982-
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
983-
]
984-
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
985-
outs(%arg2: memref<2x3x7xf32>)
975+
linalg.batch_matmul
976+
indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast
977+
affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
978+
affine_map<(batch, m, n, k) -> (batch, m, n)>]
979+
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
980+
outs(%arg2: memref<2x3x7xf32>)
986981
```
987982
}];
988983

@@ -1065,6 +1060,134 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
10651060
}
10661061

10671062

1063+
//===----------------------------------------------------------------------===//
1064+
// Op definition for BatchReduceMatmulOp
1065+
//===----------------------------------------------------------------------===//
1066+
1067+
def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
1068+
AttrSizedOperandSegments,
1069+
LinalgContractionOpInterface]> {
1070+
1071+
let summary = [{Performs a batch-reduce matrix multiplication on two inputs.
1072+
The partial multiplication results are reduced into a 2D output.}];
1073+
let description = [{
1074+
Numeric casting is performed on the operands to the inner multiply,
1075+
promoting them to the same data type as the accumulator/output.
1076+
1077+
Broadcast and Transpose semantics can be applied by specifying the explicit attribute
1078+
'indexing_maps' as shown below. This is a list attribute, so must include maps for all
1079+
arguments if specified.
1080+
1081+
Example Transpose:
1082+
```mlir
1083+
linalg.batch_reduce_matmul
1084+
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
1085+
affine_map<(batch, m, n, k) -> (batch, k, n)>,
1086+
affine_map<(batch, m, n, k) -> (m, n)>]
1087+
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
1088+
outs(%arg2: memref<3x7xf32>)
1089+
```
1090+
1091+
Example Broadcast:
1092+
```mlir
1093+
linalg.batch_reduce_matmul
1094+
indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast
1095+
affine_map<(batch, m, n, k) -> (batch, k, n)>,
1096+
affine_map<(batch, m, n, k) -> (m, n)>]
1097+
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
1098+
outs(%arg2: memref<3x7xf32>)
1099+
```
1100+
1101+
Example Broadcast and Transpose:
1102+
```mlir
1103+
linalg.batch_reduce_matmul
1104+
indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast
1105+
affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
1106+
affine_map<(batch, m, n, k) -> (m, n)>]
1107+
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
1108+
outs(%arg2: memref<3x7xf32>)
1109+
```
1110+
}];
1111+
1112+
let arguments = (ins
1113+
Variadic<AnyType>:$inputs,
1114+
Variadic<AnyShaped>:$outputs,
1115+
DefaultValuedOptionalAttr<
1116+
AffineMapArrayAttr,
1117+
"BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
1118+
>:$indexing_maps,
1119+
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
1120+
);
1121+
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
1122+
let regions = (region AnyRegion:$region);
1123+
1124+
let skipDefaultBuilders = 1;
1125+
let builders = [
1126+
OpBuilder<
1127+
(ins "ValueRange":$inputs, "ValueRange":$outputs,
1128+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
1129+
[{
1130+
buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
1131+
attributes, BatchReduceMatmulOp::getRegionBuilder(),
1132+
BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
1133+
}]>,
1134+
OpBuilder<
1135+
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1136+
"ValueRange":$outputs,
1137+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
1138+
[{
1139+
buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
1140+
inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
1141+
BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
1142+
}]>,
1143+
OpBuilder<
1144+
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1145+
"ValueRange":$outputs,
1146+
"Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
1147+
[{
1148+
$_state.addAttribute("cast", cast);
1149+
buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
1150+
attributes, BatchReduceMatmulOp::getRegionBuilder(),
1151+
BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
1152+
}]>
1153+
1154+
];
1155+
let hasCustomAssemblyFormat = 1;
1156+
let hasFolder = 1;
1157+
let hasVerifier = 1;
1158+
1159+
let extraClassDeclaration = structuredOpsBaseDecls # [{
1160+
SmallVector<utils::IteratorType> getIteratorTypesArray();
1161+
1162+
/// Implements the block region builder.
1163+
static void regionBuilder(ImplicitLocOpBuilder &b,
1164+
Block &block, ArrayRef<NamedAttribute> attrs);
1165+
1166+
/// Returns a list of AffineMap with the default batch_reduce_matmul indexing charactristic.
1167+
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
1168+
1169+
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
1170+
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
1171+
1172+
static std::function<void(ImplicitLocOpBuilder &,
1173+
Block &, ArrayRef<NamedAttribute>)>
1174+
getRegionBuilder() {
1175+
return regionBuilder;
1176+
}
1177+
1178+
::mlir::MutableOperandRange getDpsInitsMutable() {
1179+
return getOutputsMutable();
1180+
}
1181+
1182+
// Generic methods.
1183+
static unsigned getNumRegionArgs();
1184+
std::string getLibraryCallName();
1185+
bool hasDynamicIndexingMaps() { return true; };
1186+
/// Returns true if the user defined indexing maps are not equal to default maps.
1187+
bool hasUserDefinedMaps();
1188+
}];
1189+
}
1190+
10681191
//===----------------------------------------------------------------------===//
10691192
// Named Linalg ops, implemented as a declarative configurations of generic ops.
10701193
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)