diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td index 681759f970cb9..d7bd8410e360a 100644 --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -144,6 +144,69 @@ def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op; +def Transform_MemRefAllocaOp : Transform_ConcreteOpType<"memref.alloca">; + +def MemRefAllocaToGlobalOp : + Op, + DeclareOpInterfaceMethods]> { + let description = [{ + Inserts a new `memref.global` for each provided `memref.alloca` into the + nearest symbol table (e.g., a `builtin.module`) and replaces it with a + `memref.get_global`. This is useful, for example, for allocations that + should reside in the shared memory of a GPU, which have to be declared as + globals. + + #### Example + + Consider the following transform op: + + ```mlir + %get_global, %global = + transform.memref.alloca_to_global %alloca + : (!transform.op<"memref.alloca">) + -> (!transform.any_op, !transform.any_op) + ``` + + and the following input payload: + + ```mlir + module { + func.func @func() { + %alloca = memref.alloca() : memref<2x32xf32> + // usages of %alloca... + } + } + ``` + + then applying the transform op to the payload would result in the following + output IR: + + ```mlir + module { + memref.global "private" @alloc : memref<2x32xf32> + func.func @func() { + %alloca = memref.get_global @alloc : memref<2x32xf32> + // usages of %alloca... + } + } + ``` + + #### Return modes + + Succeeds always. The returned handles refer to the `memref.get_global` and + `memref.global` ops that were inserted by the transformation. + }]; + + let arguments = (ins Transform_MemRefAllocaOp:$alloca); + let results = (outs TransformHandleTypeInterface:$getGlobal, + TransformHandleTypeInterface:$global); + + let assemblyFormat = [{ + $alloca attr-dict `:` functional-type(operands, results) + }]; +} def MemRefMultiBufferOp : Op globalOps; + SmallVector getGlobalOps; + + // Transform `memref.alloca`s. + for (auto *op : allocaOps) { + auto alloca = cast(op); + MLIRContext *ctx = rewriter.getContext(); + Location loc = alloca->getLoc(); + + memref::GlobalOp globalOp; + { + // Find nearest symbol table. + Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op); + assert(symbolTableOp && "expected alloca payload to be in symbol table"); + SymbolTable symbolTable(symbolTableOp); + + // Insert a `memref.global` into the symbol table. + Type resultType = alloca.getResult().getType(); + OpBuilder builder(rewriter.getContext()); + // TODO: Add a better builder for this. + globalOp = builder.create( + loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"), + TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{}); + symbolTable.insert(globalOp); + } + + // Replace the `memref.alloca` with a `memref.get_global` accessing the + // global symbol inserted above. + rewriter.setInsertionPoint(alloca); + auto getGlobalOp = rewriter.replaceOpWithNewOp( + alloca, globalOp.getType(), globalOp.getName()); + + globalOps.push_back(globalOp); + getGlobalOps.push_back(getGlobalOp); + } + + // Assemble results. + results.set(getGlobal().cast(), globalOps); + results.set(getGetGlobal().cast(), getGlobalOps); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::MemRefAllocaToGlobalOp::getEffects( + SmallVectorImpl &effects) { + producesHandle(getGlobal(), effects); + producesHandle(getGetGlobal(), effects); + consumesHandle(getAlloca(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // MemRefMultiBufferOp //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py index 4afe8e7b887f6..1cc00bdcbf381 100644 --- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py @@ -11,6 +11,52 @@ from typing import Optional, overload, Union +class MemRefAllocaToGlobalOp: + """Specialization for MemRefAllocaToGlobalOp class.""" + + @overload + def __init__( + self, + get_global_type: Type, + global_type: Type, + alloca: Union[Operation, OpView, Value], + *, + loc=None, + ip=None + ): + ... + + @overload + def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None): + ... + + def __init__( + self, + get_global_type_or_alloca: Union[Operation, OpView, Type, Value], + global_type_or_none: Optional[Type] = None, + alloca_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None + ): + if isinstance(get_global_type_or_alloca, Type): + get_global_type = get_global_type_or_alloca + global_type = global_type_or_none + alloca = alloca_or_none + else: + get_global_type = transform.AnyOpType.get() + global_type = transform.AnyOpType.get() + alloca = get_global_type_or_alloca + + super().__init__( + get_global_type, + global_type, + alloca, + loc=loc, + ip=ip, + ) + + class MemRefMultiBufferOp: """Specialization for MemRefMultiBufferOp class.""" diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir index b19db447af1c2..68fea1f840295 100644 --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -1,5 +1,36 @@ // RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s +// CHECK-DAG: memref.global "private" @[[ALLOC0:alloc.*]] : memref<2x32xf32> +// CHECK-DAG: memref.global "private" @[[ALLOC1:alloc.*]] : memref<2x32xf32> + +// CHECK-DAG: func.func @func(%[[LB:.*]]: index, %[[UB:.*]]: index) +func.func @func(%lb: index, %ub: index) { + // CHECK-DAG: scf.forall (%[[ARG0:.*]], %[[ARG1:.*]]) in (%[[LB]], %[[UB]]) + scf.forall (%arg0, %arg1) in (%lb, %ub) { + // CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref<2x32xf32> + // CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref<2x32xf32> + // CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref<2x32xf32> + // CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref<2x32xf32> + %cst = arith.constant 0.0 : f32 + %mr0 = memref.alloca() : memref<2x32xf32> + %mr1 = memref.alloca() : memref<2x32xf32> + memref.store %cst, %mr0[%arg0, %arg1] : memref<2x32xf32> + memref.store %cst, %mr1[%arg0, %arg1] : memref<2x32xf32> + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %alloca = transform.structured.match ops{["memref.alloca"]} in %arg0 + : (!transform.any_op) -> !transform.op<"memref.alloca"> + %get_global, %global = transform.memref.alloca_to_global %alloca + : (!transform.op<"memref.alloca">) + -> (!transform.any_op, !transform.any_op) +} + +// ----- + // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)> // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py index f89005cb2f86d..e7d871c9eac8c 100644 --- a/mlir/test/python/dialects/transform_memref_ext.py +++ b/mlir/test/python/dialects/transform_memref_ext.py @@ -16,6 +16,41 @@ def run(f): return f +@run +def testMemRefAllocaToAllocOpCompact(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("memref.alloca"), + ) + with InsertionPoint(sequence.body): + memref.MemRefAllocaToGlobalOp(sequence.bodyTarget) + transform.YieldOp() + # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact + # CHECK: = transform.memref.alloca_to_global + # CHECK-SAME: (!transform.op<"memref.alloca">) + # CHECK-SAME: -> (!transform.any_op, !transform.any_op) + + +@run +def testMemRefAllocaToAllocOpTyped(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("memref.alloca"), + ) + with InsertionPoint(sequence.body): + memref.MemRefAllocaToGlobalOp( + transform.OperationType.get("memref.get_global"), + transform.OperationType.get("memref.global"), + sequence.bodyTarget, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped + # CHECK: = transform.memref.alloca_to_global + # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">) + + @run def testMemRefMultiBufferOpCompact(): sequence = transform.SequenceOp(