diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index a7fc7ddec26e6..1b85da1cf28f4 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -354,8 +354,19 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) { if (!index.has_value()) return false; + // Skip over all memref.cast ops (if any). + Operation *op = dimOp.getShapedValue().getDefiningOp(); + while (auto castOp = dyn_cast(op)) { + // Bail on unranked memrefs. + if (isa(castOp.getSource().getType())) + return false; + op = castOp.getSource().getDefiningOp(); + if (!op) + return false; + } + int64_t i = index.value(); - return TypeSwitch(dimOp.getShapedValue().getDefiningOp()) + return TypeSwitch(op) .Case( [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); }) .Default([](Operation *) { return false; });