10
10
//
11
11
//===----------------------------------------------------------------------===//
12
12
13
- #ifndef VECTOR_OPS
14
- #define VECTOR_OPS
13
+ #ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
14
+ #define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
15
15
16
+ include "mlir/Dialect/Vector/IR/Vector.td"
17
+ include "mlir/Dialect/Vector/IR/VectorAttributes.td"
18
+ include "mlir/Dialect/Arith/IR/ArithBase.td"
19
+ include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
16
20
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
17
21
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
18
22
include "mlir/IR/EnumAttr.td"
@@ -23,69 +27,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
23
27
include "mlir/Interfaces/VectorInterfaces.td"
24
28
include "mlir/Interfaces/ViewLikeInterface.td"
25
29
26
- def Vector_Dialect : Dialect {
27
- let name = "vector";
28
- let cppNamespace = "::mlir::vector";
29
-
30
- let useDefaultAttributePrinterParser = 1;
31
- let hasConstantMaterializer = 1;
32
- let dependentDialects = ["arith::ArithDialect"];
33
- }
34
-
35
- // Base class for Vector dialect ops.
36
- class Vector_Op<string mnemonic, list<Trait> traits = []> :
37
- Op<Vector_Dialect, mnemonic, traits>;
38
-
39
- // The "kind" of combining function for contractions and reductions.
40
- def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
41
- def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
42
- def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">;
43
- def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
44
- def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">;
45
- def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
46
- def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
47
- def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
48
- def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">;
49
- def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">;
50
- def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">;
51
- def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">;
52
- def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">;
53
-
54
- def CombiningKind : I32BitEnumAttr<
55
- "CombiningKind",
56
- "Kind of combining function for contractions and reductions",
57
- [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
58
- COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
59
- COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
60
- COMBINING_KIND_OR, COMBINING_KIND_XOR,
61
- COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> {
62
- let cppNamespace = "::mlir::vector";
63
- let genSpecializedAttr = 0;
64
- }
65
-
66
- /// An attribute that specifies the combining function for `vector.contract`,
67
- /// and `vector.reduction`.
68
- def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
69
- let assemblyFormat = "`<` $value `>`";
70
- }
71
-
72
- def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
73
- I32EnumAttrCase<"parallel", 0>,
74
- I32EnumAttrCase<"reduction", 1>
75
- ]> {
76
- let genSpecializedAttr = 0;
77
- let cppNamespace = "::mlir::vector";
78
- }
79
-
80
- def Vector_IteratorTypeEnum
81
- : EnumAttr<Vector_Dialect, Vector_IteratorType, "iterator_type"> {
82
- let assemblyFormat = "`<` $value `>`";
83
- }
84
-
85
- def Vector_IteratorTypeArrayAttr
86
- : TypedArrayAttrBase<Vector_IteratorTypeEnum,
87
- "Iterator type should be an enum.">;
88
-
89
30
// TODO: Add an attribute to specify a different algebra with operators other
90
31
// than the current set: {*, +}.
91
32
def Vector_ContractionOp :
@@ -274,12 +215,16 @@ def Vector_ReductionOp :
274
215
Vector_Op<"reduction", [Pure,
275
216
PredOpTrait<"source operand and result have same element type",
276
217
TCresVTEtIsSameAsOpBase<0, 0>>,
218
+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
277
219
DeclareOpInterfaceMethods<MaskableOpInterface>,
278
- DeclareOpInterfaceMethods<VectorUnrollOpInterface,
279
- ["getShapeForUnroll"]> ]>,
220
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
221
+ ]>,
280
222
Arguments<(ins Vector_CombiningKindAttr:$kind,
281
223
AnyVectorOfAnyRank:$vector,
282
- Optional<AnyType>:$acc)>,
224
+ Optional<AnyType>:$acc,
225
+ DefaultValuedAttr<
226
+ Arith_FastMathAttr,
227
+ "::mlir::arith::FastMathFlags::none">:$fastmath)>,
283
228
Results<(outs AnyType:$dest)> {
284
229
let summary = "reduction operation";
285
230
let description = [{
@@ -309,9 +254,13 @@ def Vector_ReductionOp :
309
254
}];
310
255
let builders = [
311
256
// Builder that infers the type of `dest`.
312
- OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc)>,
257
+ OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc,
258
+ CArg<"::mlir::arith::FastMathFlags",
259
+ "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>,
313
260
// Builder that infers the type of `dest` and has no accumulator.
314
- OpBuilder<(ins "CombiningKind":$kind, "Value":$vector)>
261
+ OpBuilder<(ins "CombiningKind":$kind, "Value":$vector,
262
+ CArg<"::mlir::arith::FastMathFlags",
263
+ "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
315
264
];
316
265
317
266
// TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
@@ -2469,22 +2418,6 @@ def Vector_TransposeOp :
2469
2418
let hasVerifier = 1;
2470
2419
}
2471
2420
2472
- def PrintPunctuation : I32EnumAttr<"PrintPunctuation",
2473
- "Punctuation for separating vectors or vector elements", [
2474
- I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">,
2475
- I32EnumAttrCase<"NewLine", 1, "newline">,
2476
- I32EnumAttrCase<"Comma", 2, "comma">,
2477
- I32EnumAttrCase<"Open", 3, "open">,
2478
- I32EnumAttrCase<"Close", 4, "close">
2479
- ]> {
2480
- let cppNamespace = "::mlir::vector";
2481
- let genSpecializedAttr = 0;
2482
- }
2483
-
2484
- def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctuation"> {
2485
- let assemblyFormat = "`<` $value `>`";
2486
- }
2487
-
2488
2421
def Vector_PrintOp :
2489
2422
Vector_Op<"print", []>,
2490
2423
Arguments<(ins Optional<Type<Or<[
@@ -2939,4 +2872,4 @@ def Vector_WarpExecuteOnLane0Op : Vector_Op<"warp_execute_on_lane_0",
2939
2872
}];
2940
2873
}
2941
2874
2942
- #endif // VECTOR_OPS
2875
+ #endif // MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
0 commit comments