Skip to content

Commit 1b8b556

Browse files
[mlir][Vector] Add fastmath flags to vector.reduction (#66905)
This revision pipes the fastmath attribute support through the vector.reduction op. This seemingly simple first step already requires quite some genuflexions, file and builder reorganization. In the process, retire the boolean reassoc flag deep in the LLVM dialect builders and just use the fastmath attribute. During conversions, templated builders for predicated intrinsics are partially cleaned up. In the future, to finalize the cleanups, one should consider adding fastmath to the VPIntrinsic ops.
1 parent ebefe83 commit 1b8b556

File tree

15 files changed

+322
-226
lines changed

15 files changed

+322
-226
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

+8-26
Original file line numberDiff line numberDiff line change
@@ -654,32 +654,14 @@ class LLVM_VecReductionI<string mnem>
654654
// LLVM vector reduction over a single vector, with an initial value,
655655
// and with permission to reassociate the reduction operations.
656656
class LLVM_VecReductionAccBase<string mnem, Type element>
657-
: LLVM_OneResultIntrOp<"vector.reduce." # mnem, [], [0],
658-
[Pure, SameOperandsAndResultElementType]>,
659-
Arguments<(ins element:$start_value, LLVM_VectorOf<element>:$input,
660-
DefaultValuedAttr<BoolAttr, "false">:$reassoc)> {
661-
let llvmBuilder = [{
662-
llvm::Module *module = builder.GetInsertBlock()->getModule();
663-
llvm::Function *fn = llvm::Intrinsic::getDeclaration(
664-
module,
665-
llvm::Intrinsic::vector_reduce_}] # mnem # [{,
666-
{ }] # !interleave(ListIntSubst<LLVM_IntrPatterns.operand, [1]>.lst,
667-
", ") # [{
668-
});
669-
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
670-
llvm::FastMathFlags origFM = builder.getFastMathFlags();
671-
llvm::FastMathFlags tempFM = origFM;
672-
tempFM.setAllowReassoc($reassoc);
673-
builder.setFastMathFlags(tempFM); // set fastmath flag
674-
$res = builder.CreateCall(fn, operands);
675-
builder.setFastMathFlags(origFM); // restore fastmath flag
676-
}];
677-
let mlirBuilder = [{
678-
bool allowReassoc = inst->getFastMathFlags().allowReassoc();
679-
$res = $_builder.create<$_qualCppClassName>($_location,
680-
$_resultType, $start_value, $input, allowReassoc);
681-
}];
682-
}
657+
: LLVM_OneResultIntrOp</*mnem=*/"vector.reduce." # mnem,
658+
/*overloadedResults=*/[],
659+
/*overloadedOperands=*/[1],
660+
/*traits=*/[Pure, SameOperandsAndResultElementType],
661+
/*equiresFastmath=*/1>,
662+
Arguments<(ins element:$start_value,
663+
LLVM_VectorOf<element>:$input,
664+
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags)>;
683665

684666
class LLVM_VecReductionAccF<string mnem>
685667
: LLVM_VecReductionAccBase<mnem, AnyFloat>;
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
add_mlir_dialect(VectorOps vector)
2-
add_mlir_doc(VectorOps VectorOps Dialects/ -gen-op-doc)
1+
add_mlir_dialect(Vector vector)
2+
add_mlir_doc(Vector Vector Dialects/ -gen-op-doc -dialect=vector)
33

4+
# Add Vector operations
45
set(LLVM_TARGET_DEFINITIONS VectorOps.td)
5-
mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls)
6-
mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs)
7-
mlir_tablegen(VectorOpsAttrDefs.h.inc -gen-attrdef-decls)
8-
mlir_tablegen(VectorOpsAttrDefs.cpp.inc -gen-attrdef-defs)
9-
add_public_tablegen_target(MLIRVectorOpsEnumsIncGen)
10-
add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen)
6+
mlir_tablegen(VectorOps.h.inc -gen-op-decls)
7+
mlir_tablegen(VectorOps.cpp.inc -gen-op-defs)
8+
add_public_tablegen_target(MLIRVectorOpsIncGen)
9+
add_dependencies(mlir-generic-headers MLIRVectorOpsIncGen)
10+
11+
# Add Vector attributes
12+
set(LLVM_TARGET_DEFINITIONS VectorAttributes.td)
13+
mlir_tablegen(VectorEnums.h.inc -gen-enum-decls)
14+
mlir_tablegen(VectorEnums.cpp.inc -gen-enum-defs)
15+
mlir_tablegen(VectorAttributes.h.inc -gen-attrdef-decls)
16+
mlir_tablegen(VectorAttributes.cpp.inc -gen-attrdef-defs)
17+
add_public_tablegen_target(MLIRVectorAttributesIncGen)
18+
add_dependencies(mlir-generic-headers MLIRVectorAttributesIncGen)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- Vector.td - Vector Dialect --------------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the Vector dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR
14+
#define MLIR_DIALECT_VECTOR_IR_VECTOR
15+
16+
include "mlir/IR/OpBase.td"
17+
18+
def Vector_Dialect : Dialect {
19+
let name = "vector";
20+
let cppNamespace = "::mlir::vector";
21+
22+
let useDefaultAttributePrinterParser = 1;
23+
let hasConstantMaterializer = 1;
24+
let dependentDialects = ["arith::ArithDialect"];
25+
}
26+
27+
// Base class for Vector dialect ops.
28+
class Vector_Op<string mnemonic, list<Trait> traits = []> :
29+
Op<Vector_Dialect, mnemonic, traits>;
30+
31+
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
//===- VectorAttributes.td - Vector Dialect ----------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the attributes used in the Vector dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
14+
#define MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
15+
16+
include "mlir/Dialect/Vector/IR/Vector.td"
17+
include "mlir/IR/EnumAttr.td"
18+
19+
// The "kind" of combining function for contractions and reductions.
20+
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
21+
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
22+
def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">;
23+
def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
24+
def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">;
25+
def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
26+
def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
27+
def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
28+
def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">;
29+
def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">;
30+
def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">;
31+
def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">;
32+
def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">;
33+
34+
def CombiningKind : I32BitEnumAttr<
35+
"CombiningKind",
36+
"Kind of combining function for contractions and reductions",
37+
[COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
38+
COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
39+
COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
40+
COMBINING_KIND_OR, COMBINING_KIND_XOR,
41+
COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> {
42+
let cppNamespace = "::mlir::vector";
43+
let genSpecializedAttr = 0;
44+
}
45+
46+
/// An attribute that specifies the combining function for `vector.contract`,
47+
/// and `vector.reduction`.
48+
def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
49+
let assemblyFormat = "`<` $value `>`";
50+
}
51+
52+
def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
53+
I32EnumAttrCase<"parallel", 0>,
54+
I32EnumAttrCase<"reduction", 1>
55+
]> {
56+
let genSpecializedAttr = 0;
57+
let cppNamespace = "::mlir::vector";
58+
}
59+
60+
def Vector_IteratorTypeEnum
61+
: EnumAttr<Vector_Dialect, Vector_IteratorType, "iterator_type"> {
62+
let assemblyFormat = "`<` $value `>`";
63+
}
64+
65+
def Vector_IteratorTypeArrayAttr
66+
: TypedArrayAttrBase<Vector_IteratorTypeEnum,
67+
"Iterator type should be an enum.">;
68+
69+
def PrintPunctuation : I32EnumAttr<"PrintPunctuation",
70+
"Punctuation for separating vectors or vector elements", [
71+
I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">,
72+
I32EnumAttrCase<"NewLine", 1, "newline">,
73+
I32EnumAttrCase<"Comma", 2, "comma">,
74+
I32EnumAttrCase<"Open", 3, "open">,
75+
I32EnumAttrCase<"Close", 4, "close">
76+
]> {
77+
let cppNamespace = "::mlir::vector";
78+
let genSpecializedAttr = 0;
79+
}
80+
81+
def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctuation"> {
82+
let assemblyFormat = "`<` $value `>`";
83+
}
84+
85+
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
1515

1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
17+
#include "mlir/Dialect/Arith/IR/Arith.h"
1718
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
1819
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
1920
#include "mlir/IR/AffineMap.h"
@@ -31,10 +32,10 @@
3132
#include "llvm/ADT/StringExtras.h"
3233

3334
// Pull in all enum type definitions and utility function declarations.
34-
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc"
35+
#include "mlir/Dialect/Vector/IR/VectorEnums.h.inc"
3536

3637
#define GET_ATTRDEF_CLASSES
37-
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc"
38+
#include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc"
3839

3940
namespace mlir {
4041
class MLIRContext;
@@ -157,7 +158,7 @@ Value selectPassthru(OpBuilder &builder, Value mask, Value newValue,
157158
} // namespace mlir
158159

159160
#define GET_OP_CLASSES
161+
#include "mlir/Dialect/Vector/IR/VectorDialect.h.inc"
160162
#include "mlir/Dialect/Vector/IR/VectorOps.h.inc"
161-
#include "mlir/Dialect/Vector/IR/VectorOpsDialect.h.inc"
162163

163164
#endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

+20-87
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13-
#ifndef VECTOR_OPS
14-
#define VECTOR_OPS
13+
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
14+
#define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
1515

16+
include "mlir/Dialect/Vector/IR/Vector.td"
17+
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
18+
include "mlir/Dialect/Arith/IR/ArithBase.td"
19+
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
1620
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
1721
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
1822
include "mlir/IR/EnumAttr.td"
@@ -23,69 +27,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
2327
include "mlir/Interfaces/VectorInterfaces.td"
2428
include "mlir/Interfaces/ViewLikeInterface.td"
2529

26-
def Vector_Dialect : Dialect {
27-
let name = "vector";
28-
let cppNamespace = "::mlir::vector";
29-
30-
let useDefaultAttributePrinterParser = 1;
31-
let hasConstantMaterializer = 1;
32-
let dependentDialects = ["arith::ArithDialect"];
33-
}
34-
35-
// Base class for Vector dialect ops.
36-
class Vector_Op<string mnemonic, list<Trait> traits = []> :
37-
Op<Vector_Dialect, mnemonic, traits>;
38-
39-
// The "kind" of combining function for contractions and reductions.
40-
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
41-
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
42-
def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">;
43-
def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
44-
def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">;
45-
def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
46-
def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
47-
def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
48-
def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">;
49-
def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">;
50-
def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">;
51-
def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">;
52-
def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">;
53-
54-
def CombiningKind : I32BitEnumAttr<
55-
"CombiningKind",
56-
"Kind of combining function for contractions and reductions",
57-
[COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
58-
COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
59-
COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
60-
COMBINING_KIND_OR, COMBINING_KIND_XOR,
61-
COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> {
62-
let cppNamespace = "::mlir::vector";
63-
let genSpecializedAttr = 0;
64-
}
65-
66-
/// An attribute that specifies the combining function for `vector.contract`,
67-
/// and `vector.reduction`.
68-
def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
69-
let assemblyFormat = "`<` $value `>`";
70-
}
71-
72-
def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
73-
I32EnumAttrCase<"parallel", 0>,
74-
I32EnumAttrCase<"reduction", 1>
75-
]> {
76-
let genSpecializedAttr = 0;
77-
let cppNamespace = "::mlir::vector";
78-
}
79-
80-
def Vector_IteratorTypeEnum
81-
: EnumAttr<Vector_Dialect, Vector_IteratorType, "iterator_type"> {
82-
let assemblyFormat = "`<` $value `>`";
83-
}
84-
85-
def Vector_IteratorTypeArrayAttr
86-
: TypedArrayAttrBase<Vector_IteratorTypeEnum,
87-
"Iterator type should be an enum.">;
88-
8930
// TODO: Add an attribute to specify a different algebra with operators other
9031
// than the current set: {*, +}.
9132
def Vector_ContractionOp :
@@ -274,12 +215,16 @@ def Vector_ReductionOp :
274215
Vector_Op<"reduction", [Pure,
275216
PredOpTrait<"source operand and result have same element type",
276217
TCresVTEtIsSameAsOpBase<0, 0>>,
218+
DeclareOpInterfaceMethods<ArithFastMathInterface>,
277219
DeclareOpInterfaceMethods<MaskableOpInterface>,
278-
DeclareOpInterfaceMethods<VectorUnrollOpInterface,
279-
["getShapeForUnroll"]>]>,
220+
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
221+
]>,
280222
Arguments<(ins Vector_CombiningKindAttr:$kind,
281223
AnyVectorOfAnyRank:$vector,
282-
Optional<AnyType>:$acc)>,
224+
Optional<AnyType>:$acc,
225+
DefaultValuedAttr<
226+
Arith_FastMathAttr,
227+
"::mlir::arith::FastMathFlags::none">:$fastmath)>,
283228
Results<(outs AnyType:$dest)> {
284229
let summary = "reduction operation";
285230
let description = [{
@@ -309,9 +254,13 @@ def Vector_ReductionOp :
309254
}];
310255
let builders = [
311256
// Builder that infers the type of `dest`.
312-
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc)>,
257+
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc,
258+
CArg<"::mlir::arith::FastMathFlags",
259+
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>,
313260
// Builder that infers the type of `dest` and has no accumulator.
314-
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector)>
261+
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector,
262+
CArg<"::mlir::arith::FastMathFlags",
263+
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
315264
];
316265

317266
// TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
@@ -2469,22 +2418,6 @@ def Vector_TransposeOp :
24692418
let hasVerifier = 1;
24702419
}
24712420

2472-
def PrintPunctuation : I32EnumAttr<"PrintPunctuation",
2473-
"Punctuation for separating vectors or vector elements", [
2474-
I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">,
2475-
I32EnumAttrCase<"NewLine", 1, "newline">,
2476-
I32EnumAttrCase<"Comma", 2, "comma">,
2477-
I32EnumAttrCase<"Open", 3, "open">,
2478-
I32EnumAttrCase<"Close", 4, "close">
2479-
]> {
2480-
let cppNamespace = "::mlir::vector";
2481-
let genSpecializedAttr = 0;
2482-
}
2483-
2484-
def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctuation"> {
2485-
let assemblyFormat = "`<` $value `>`";
2486-
}
2487-
24882421
def Vector_PrintOp :
24892422
Vector_Op<"print", []>,
24902423
Arguments<(ins Optional<Type<Or<[
@@ -2939,4 +2872,4 @@ def Vector_WarpExecuteOnLane0Op : Vector_Op<"warp_execute_on_lane_0",
29392872
}];
29402873
}
29412874

2942-
#endif // VECTOR_OPS
2875+
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_OPS

0 commit comments

Comments
 (0)