diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h index 08afdf373f014..0fcaa96ade403 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -110,8 +110,12 @@ struct ConvolutionDimensions { FailureOr inferConvolutionDims(LinalgOp linalgOp); /// Checks whether `linalgOp` conforms to ConvolutionOpInterface. +/// By default, we require the `linalgOp` to have non-empty convolved dims +/// (implicitly non-empty `output_image` and `filter_loop`). +/// Users can loosen the constraint by setting `allowEmptyConvolvedDims` to true // TODO: embed within `isa` if possible / natural. -bool isaConvolutionOpInterface(LinalgOp linalgOp); +bool isaConvolutionOpInterface(LinalgOp linalgOp, + bool allowEmptyConvolvedDims = false); /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`. bool isaCopyOpInterface(LinalgOp linalgOp); @@ -175,9 +179,12 @@ enum class MatchConvolutionResult; /// Checks whether `op` conforms to ConvolutionOpInterface and populates /// `dimensions` with indexes of the different kinds of dimensions when /// present. +/// If `allowEmptyConvolvedDims` is not set, we further checks whether the `op` +/// contains convolved dims. MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, - ConvolutionDimensions *dimensions = nullptr); + ConvolutionDimensions *dimensions = nullptr, + bool allowEmptyConvolvedDims = false); /// Returns the error message corresponding to the convolution checking return /// code. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 6ee1810c2ff2b..a38b20eed3a00 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -762,13 +762,15 @@ enum class MatchConvolutionResult { NotProjectedPermutations, NonConvolutionLoop, OutputDimsNotParallel, - NonOutputDimNotReduction + NonOutputDimNotReduction, + EmptyConvolvedDims }; } // namespace mlir::linalg::detail mlir::linalg::detail::MatchConvolutionResult mlir::linalg::detail::isConvolutionInterfaceImpl( - Operation *op, ConvolutionDimensions *dimensions) { + Operation *op, ConvolutionDimensions *dimensions, + bool allowEmptyConvolvedDims) { auto linalgOp = dyn_cast(op); if (!linalgOp) return MatchConvolutionResult::NotLinalgOp; @@ -886,10 +888,12 @@ mlir::linalg::detail::isConvolutionInterfaceImpl( if (allLoopDims.size() != linalgOp.getNumLoops()) return MatchConvolutionResult::NonConvolutionLoop; + if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty()) + return MatchConvolutionResult::EmptyConvolvedDims; + if (dimensions) { - FailureOr res = - inferConvolutionDimsImpl(linalgOp, inputExprWalker, - /*allowEmptyConvolvedDims=*/true); + FailureOr res = inferConvolutionDimsImpl( + linalgOp, inputExprWalker, allowEmptyConvolvedDims); assert(succeeded(res) && "unexpected failure to infer convolution dims"); *dimensions = *res; } @@ -920,8 +924,10 @@ mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) { llvm_unreachable("unhandled MatchConvolutionResult case"); } -bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp) { - return linalg::detail::isConvolutionInterfaceImpl(linalgOp.getOperation()) == +bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp, + bool allowEmptyConvolvedDims) { + return linalg::detail::isConvolutionInterfaceImpl( + linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) == linalg::detail::MatchConvolutionResult::Success; }