diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td index a0eb5ff00cb9f..9b588eb610e51 100644 --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -85,6 +85,7 @@ def OneShotBufferizeOp DefaultValuedAttr:$allow_return_allocs_from_loops, DefaultValuedAttr:$allow_unknown_ops, DefaultValuedAttr:$bufferize_function_boundaries, + DefaultValuedAttr:$dump_alias_sets, DefaultValuedAttr:$test_analysis_only, DefaultValuedAttr:$print_conflicts, DefaultValuedAttr:$memcpy_op); diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index 354ed162a15ea..b7db4917a4138 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -47,6 +47,10 @@ void transform::BufferLoopHoistingOp::getEffects( LogicalResult transform::OneShotBufferizeOp::verify() { if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy") return emitOpError() << "unsupported memcpy op"; + if (getPrintConflicts() && !getTestAnalysisOnly()) + return emitOpError() << "'print_conflicts' requires 'test_analysis_only'"; + if (getDumpAliasSets() && !getTestAnalysisOnly()) + return emitOpError() << "'dump_alias_sets' requires 'test_analysis_only'"; return success(); } @@ -58,6 +62,7 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, options.allowReturnAllocsFromLoops = getAllowReturnAllocsFromLoops(); options.allowUnknownOps = getAllowUnknownOps(); options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); + options.dumpAliasSets = getDumpAliasSets(); options.testAnalysisOnly = getTestAnalysisOnly(); options.printConflicts = getPrintConflicts(); if (getFunctionBoundaryTypeConversion().has_value())