Skip to content

Commit 465c660

Browse files
[mlir][memref] Add a new InderStaticShapes pass for ReifyRankedShapedTypeOpInterface
1 parent d83457e commit 465c660

File tree

4 files changed

+161
-0
lines changed

4 files changed

+161
-0
lines changed

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,19 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
182182
];
183183
}
184184

185+
def InferStaticShapesPass : Pass<"infer-static-shapes"> {
186+
let summary = "Resolve memref.dim of result values";
187+
let description = [{
188+
The pass resolves memref.dim of result of operations that
189+
implement the `InferShapedTypeOpInterface` or
190+
`ReifyRankedShapedTypeOpInterface` in terms of shapes of its
191+
operands.
192+
}];
193+
let dependentDialects = [
194+
"affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
195+
];
196+
}
197+
185198
def ExpandStridedMetadataPass : Pass<"expand-strided-metadata"> {
186199
let summary = "Expand memref operations into easier to analyze constructs";
187200
let description = [{

mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ void populateResolveRankedShapedTypeResultDimsPatterns(
5757
/// terms of shapes of its input operands.
5858
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
5959

60+
/// Appends patterns that allow making ReifyRankedShapedTypeOpInterface ops
61+
/// shapes more static.
62+
void populateReifyToInferStaticShapePatterns(RewritePatternSet &patterns);
63+
6064
/// Appends patterns for expanding memref operations that modify the metadata
6165
/// (sizes, offset, strides) of a memref into easier to analyze constructs.
6266
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,22 @@
2020
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
2121
#include "mlir/Dialect/SCF/IR/SCF.h"
2222
#include "mlir/Dialect/Tensor/IR/Tensor.h"
23+
#include "mlir/IR/BuiltinTypeInterfaces.h"
24+
#include "mlir/IR/BuiltinTypes.h"
25+
#include "mlir/IR/Value.h"
2326
#include "mlir/Interfaces/InferTypeOpInterface.h"
2427
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28+
#include "llvm/Support/ErrorHandling.h"
29+
#include "llvm/Support/InterleavedRange.h"
30+
31+
#define DEBUG_TYPE "resolve-shaped-type"
32+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
2533

2634
namespace mlir {
2735
namespace memref {
2836
#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
2937
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
38+
#define GEN_PASS_DEF_INFERSTATICSHAPESPASS
3039
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
3140
} // namespace memref
3241
} // namespace mlir
@@ -105,6 +114,99 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
105114
}
106115
};
107116

117+
struct ReifyToInferStaticShapePattern
118+
: public OpInterfaceRewritePattern<ReifyRankedShapedTypeOpInterface> {
119+
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
120+
121+
LogicalResult matchAndRewrite(ReifyRankedShapedTypeOpInterface op,
122+
PatternRewriter &rewriter) const override {
123+
LLVM_DEBUG(
124+
{ DBGS() << "ReifyToInferStaticShapePattern on " << op << "\n"; });
125+
126+
bool rewriteToMoreStatic = false;
127+
ReifiedRankedShapedTypeDims reifiedResultShapes;
128+
if (failed(reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
129+
reifiedResultShapes.empty()) {
130+
LLVM_DEBUG({ DBGS() << "reifyResultShapes failed\n"; });
131+
return failure();
132+
}
133+
134+
SmallVector<Type> newTypes;
135+
for (auto [t, reifiedShape] :
136+
llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
137+
ShapedType st = dyn_cast<ShapedType>(t);
138+
if (!st)
139+
continue;
140+
141+
SmallVector<int64_t> newShape;
142+
for (const auto &[s, ofr] :
143+
llvm::zip_equal(st.getShape(), reifiedShape)) {
144+
std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
145+
// Reification does not add static information, just use existing shape.
146+
if (!maybeCst.has_value()) {
147+
newShape.push_back(s);
148+
continue;
149+
}
150+
int64_t cst = *maybeCst;
151+
assert((ShapedType::isDynamic(s) || s == cst) &&
152+
"constants must agree!");
153+
newShape.push_back(cst);
154+
}
155+
156+
if (newShape == st.getShape()) {
157+
newTypes.push_back(t);
158+
continue;
159+
}
160+
161+
rewriteToMoreStatic = true;
162+
Type newType = st.cloneWith(newShape, st.getElementType());
163+
newTypes.push_back(newType);
164+
}
165+
166+
LLVM_DEBUG({
167+
DBGS() << "--oldTypes: " << llvm::interleaved_array(op->getResultTypes())
168+
<< " \n";
169+
DBGS() << "--newTypes: " << llvm::interleaved_array(newTypes) << " \n";
170+
});
171+
if (!rewriteToMoreStatic) {
172+
LLVM_DEBUG({ DBGS() << "not more static\n"; });
173+
return failure();
174+
}
175+
176+
// We now have newTypes that need to be turned to tensor::CastOp.
177+
Location loc = op->getLoc();
178+
SmallVector<Value> newResults;
179+
Operation *newOp = rewriter.clone(*op);
180+
for (auto [nt, oldVal] : llvm::zip(newTypes, op->getResults())) {
181+
Type ot = oldVal.getType();
182+
OpResult newResult = newOp->getResult(oldVal.getResultNumber());
183+
if (ot == nt) {
184+
newResults.push_back(newResult);
185+
continue;
186+
}
187+
newResult.setType(nt);
188+
if (isa<RankedTensorType>(nt)) {
189+
newResults.push_back(
190+
rewriter.create<tensor::CastOp>(loc, ot, newResult));
191+
} else if (isa<MemRefType>(nt)) {
192+
newResults.push_back(
193+
rewriter.create<memref::CastOp>(loc, ot, newResult));
194+
} else {
195+
llvm_unreachable("expected RankedTensorType or MemRefType");
196+
}
197+
}
198+
199+
LLVM_DEBUG({
200+
op->getParentOp()->dump();
201+
DBGS() << "replace op " << *op << "\n";
202+
DBGS() << "with newResults " << llvm::interleaved_array(newResults)
203+
<< "\n\n\n\n";
204+
});
205+
rewriter.replaceAllOpUsesWith(op, newResults);
206+
return success();
207+
}
208+
};
209+
108210
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
109211
///
110212
/// ```
@@ -175,6 +277,11 @@ struct ResolveShapedTypeResultDimsPass final
175277
void runOnOperation() override;
176278
};
177279

280+
struct InferStaticShapesPass final
281+
: public memref::impl::InferStaticShapesPassBase<InferStaticShapesPass> {
282+
void runOnOperation() override;
283+
};
284+
178285
} // namespace
179286

180287
void memref::populateResolveRankedShapedTypeResultDimsPatterns(
@@ -192,6 +299,11 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
192299
patterns.getContext());
193300
}
194301

302+
void memref::populateReifyToInferStaticShapePatterns(
303+
RewritePatternSet &patterns) {
304+
patterns.add<ReifyToInferStaticShapePattern>(patterns.getContext());
305+
}
306+
195307
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
196308
RewritePatternSet patterns(&getContext());
197309
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
@@ -206,3 +318,17 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
206318
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
207319
return signalPassFailure();
208320
}
321+
322+
void InferStaticShapesPass::runOnOperation() {
323+
RewritePatternSet patterns(&getContext());
324+
patterns.add<ReifyToInferStaticShapePattern>(&getContext());
325+
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
326+
327+
SmallVector<Operation *> opsToSimplify;
328+
getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
329+
opsToSimplify.push_back(op);
330+
});
331+
(void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns,
332+
GreedyRewriteConfig().setStrictness(
333+
GreedyRewriteStrictness::ExistingOps));
334+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: mlir-opt -infer-static-shapes -split-input-file %s | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @pad_reification
4+
func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
5+
-> tensor<1x?x64xf32> {
6+
%pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
7+
%es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1]
8+
: tensor<64x?x64xf32> to tensor<1x?x64xf32>
9+
10+
// CHECK: tensor.pad
11+
// CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
12+
%padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
13+
^bb0(%a: index, %b: index, %c: index):
14+
tensor.yield %cst : f32
15+
} : tensor<1x?x64xf32> to tensor<1x?x64xf32>
16+
17+
return %padded : tensor<1x?x64xf32>
18+
}

0 commit comments

Comments
 (0)