20
20
#include " mlir/Dialect/MemRef/Transforms/Transforms.h"
21
21
#include " mlir/Dialect/SCF/IR/SCF.h"
22
22
#include " mlir/Dialect/Tensor/IR/Tensor.h"
23
+ #include " mlir/IR/BuiltinTypeInterfaces.h"
24
+ #include " mlir/IR/BuiltinTypes.h"
25
+ #include " mlir/IR/Value.h"
23
26
#include " mlir/Interfaces/InferTypeOpInterface.h"
24
27
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
28
+ #include " llvm/Support/ErrorHandling.h"
29
+ #include " llvm/Support/InterleavedRange.h"
30
+
31
+ #define DEBUG_TYPE " resolve-shaped-type"
32
+ #define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE << " ]: " )
25
33
26
34
namespace mlir {
27
35
namespace memref {
28
36
#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
29
37
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
38
+ #define GEN_PASS_DEF_INFERSTATICSHAPESPASS
30
39
#include " mlir/Dialect/MemRef/Transforms/Passes.h.inc"
31
40
} // namespace memref
32
41
} // namespace mlir
@@ -105,6 +114,99 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
105
114
}
106
115
};
107
116
117
+ struct ReifyToInferStaticShapePattern
118
+ : public OpInterfaceRewritePattern<ReifyRankedShapedTypeOpInterface> {
119
+ using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
120
+
121
+ LogicalResult matchAndRewrite (ReifyRankedShapedTypeOpInterface op,
122
+ PatternRewriter &rewriter) const override {
123
+ LLVM_DEBUG (
124
+ { DBGS () << " ReifyToInferStaticShapePattern on " << op << " \n " ; });
125
+
126
+ bool rewriteToMoreStatic = false ;
127
+ ReifiedRankedShapedTypeDims reifiedResultShapes;
128
+ if (failed (reifyResultShapes (rewriter, op, reifiedResultShapes)) ||
129
+ reifiedResultShapes.empty ()) {
130
+ LLVM_DEBUG ({ DBGS () << " reifyResultShapes failed\n " ; });
131
+ return failure ();
132
+ }
133
+
134
+ SmallVector<Type> newTypes;
135
+ for (auto [t, reifiedShape] :
136
+ llvm::zip (op->getResultTypes (), reifiedResultShapes)) {
137
+ ShapedType st = dyn_cast<ShapedType>(t);
138
+ if (!st)
139
+ continue ;
140
+
141
+ SmallVector<int64_t > newShape;
142
+ for (const auto &[s, ofr] :
143
+ llvm::zip_equal (st.getShape (), reifiedShape)) {
144
+ std::optional<int64_t > maybeCst = getConstantIntValue (ofr);
145
+ // Reification does not add static information, just use existing shape.
146
+ if (!maybeCst.has_value ()) {
147
+ newShape.push_back (s);
148
+ continue ;
149
+ }
150
+ int64_t cst = *maybeCst;
151
+ assert ((ShapedType::isDynamic (s) || s == cst) &&
152
+ " constants must agree!" );
153
+ newShape.push_back (cst);
154
+ }
155
+
156
+ if (newShape == st.getShape ()) {
157
+ newTypes.push_back (t);
158
+ continue ;
159
+ }
160
+
161
+ rewriteToMoreStatic = true ;
162
+ Type newType = st.cloneWith (newShape, st.getElementType ());
163
+ newTypes.push_back (newType);
164
+ }
165
+
166
+ LLVM_DEBUG ({
167
+ DBGS () << " --oldTypes: " << llvm::interleaved_array (op->getResultTypes ())
168
+ << " \n " ;
169
+ DBGS () << " --newTypes: " << llvm::interleaved_array (newTypes) << " \n " ;
170
+ });
171
+ if (!rewriteToMoreStatic) {
172
+ LLVM_DEBUG ({ DBGS () << " not more static\n " ; });
173
+ return failure ();
174
+ }
175
+
176
+ // We now have newTypes that need to be turned to tensor::CastOp.
177
+ Location loc = op->getLoc ();
178
+ SmallVector<Value> newResults;
179
+ Operation *newOp = rewriter.clone (*op);
180
+ for (auto [nt, oldVal] : llvm::zip (newTypes, op->getResults ())) {
181
+ Type ot = oldVal.getType ();
182
+ OpResult newResult = newOp->getResult (oldVal.getResultNumber ());
183
+ if (ot == nt) {
184
+ newResults.push_back (newResult);
185
+ continue ;
186
+ }
187
+ newResult.setType (nt);
188
+ if (isa<RankedTensorType>(nt)) {
189
+ newResults.push_back (
190
+ rewriter.create <tensor::CastOp>(loc, ot, newResult));
191
+ } else if (isa<MemRefType>(nt)) {
192
+ newResults.push_back (
193
+ rewriter.create <memref::CastOp>(loc, ot, newResult));
194
+ } else {
195
+ llvm_unreachable (" expected RankedTensorType or MemRefType" );
196
+ }
197
+ }
198
+
199
+ LLVM_DEBUG ({
200
+ op->getParentOp ()->dump ();
201
+ DBGS () << " replace op " << *op << " \n " ;
202
+ DBGS () << " with newResults " << llvm::interleaved_array (newResults)
203
+ << " \n\n\n\n " ;
204
+ });
205
+ rewriter.replaceAllOpUsesWith (op, newResults);
206
+ return success ();
207
+ }
208
+ };
209
+
108
210
// / Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
109
211
// /
110
212
// / ```
@@ -175,6 +277,11 @@ struct ResolveShapedTypeResultDimsPass final
175
277
void runOnOperation () override ;
176
278
};
177
279
280
+ struct InferStaticShapesPass final
281
+ : public memref::impl::InferStaticShapesPassBase<InferStaticShapesPass> {
282
+ void runOnOperation () override ;
283
+ };
284
+
178
285
} // namespace
179
286
180
287
void memref::populateResolveRankedShapedTypeResultDimsPatterns (
@@ -192,6 +299,11 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
192
299
patterns.getContext ());
193
300
}
194
301
302
+ void memref::populateReifyToInferStaticShapePatterns (
303
+ RewritePatternSet &patterns) {
304
+ patterns.add <ReifyToInferStaticShapePattern>(patterns.getContext ());
305
+ }
306
+
195
307
void ResolveRankedShapeTypeResultDimsPass::runOnOperation () {
196
308
RewritePatternSet patterns (&getContext ());
197
309
memref::populateResolveRankedShapedTypeResultDimsPatterns (patterns);
@@ -206,3 +318,17 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
206
318
if (failed (applyPatternsGreedily (getOperation (), std::move (patterns))))
207
319
return signalPassFailure ();
208
320
}
321
+
322
+ void InferStaticShapesPass::runOnOperation () {
323
+ RewritePatternSet patterns (&getContext ());
324
+ patterns.add <ReifyToInferStaticShapePattern>(&getContext ());
325
+ FrozenRewritePatternSet frozenPatterns (std::move (patterns));
326
+
327
+ SmallVector<Operation *> opsToSimplify;
328
+ getOperation ()->walk ([&](ReifyRankedShapedTypeOpInterface op) {
329
+ opsToSimplify.push_back (op);
330
+ });
331
+ (void )applyOpPatternsGreedily (opsToSimplify, frozenPatterns,
332
+ GreedyRewriteConfig ().setStrictness (
333
+ GreedyRewriteStrictness::ExistingOps));
334
+ }
0 commit comments