From 594b498f5e03c00ed27de42b1ccf57af781c0f22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 13 Sep 2023 12:27:31 +0000 Subject: [PATCH 1/4] [mlir][memref][transform] Add new alloca_to_global op. This PR adds a new transform op that replaces `memref.alloca`s with `memref.get_global`s to newly inserted `memref.global`s. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals. --- .../MemRef/TransformOps/MemRefTransformOps.td | 65 ++++++++++++++ .../TransformOps/MemRefTransformOps.cpp | 90 +++++++++++++++++++ .../dialects/_memref_transform_ops_ext.py | 58 ++++++++++++ mlir/test/Dialect/MemRef/transform-ops.mlir | 39 ++++++++ .../python/dialects/transform_memref_ext.py | 48 ++++++++++ 5 files changed, 300 insertions(+) diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td index 681759f970cb9..6a78784d74dd5 100644 --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -144,6 +144,71 @@ def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op; +def Transform_MemRefAllocaOp : Transform_ConcreteOpType<"memref.alloca">; + +def MemRefAllocaToGlobalOp : + Op, + DeclareOpInterfaceMethods]> { + let description = [{ + Inserts a new `memref.global` for each provided `memref.alloca` into the + provided module and replaces it with a `memref.get_global`. This is useful, + for example, for allocations that should reside in the shared memory of + a GPU, which have to be declared as globals. + + #### Example + + Consider the following transform op: + + ```mlir + %get_global, %global = + transform.memref.alloca_to_global %alloca in %module + : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">) + -> (!transform.any_op, !transform.any_op) + ``` + + and the following input payload: + + ```mlir + module { + func.func @func() { + %alloca = memref.alloca() : memref<2x32xf32> + // usages of %alloca... + } + } + ``` + + then applying the transform op to the payload would result in the following + output IR: + + ```mlir + module { + memref.global "private" @alloc : memref<2x32xf32> + func.func @func() { + %alloca = memref.get_global @alloc : memref<2x32xf32> + // usages of %alloca... + } + } + ``` + + #### Return modes + + Emits a definite failure if not exactly one `module` payload op was provided + or any of the `alloca` payload ops is not inside that module, and succeeds + otherwise. The returned handles refer to the `memref.get_global` and + `memref.global` ops that were inserted by the transformation. + }]; + + let arguments = (ins Transform_ConcreteOpType<"builtin.module">:$module, + Transform_MemRefAllocaOp:$alloca); + let results = (outs TransformHandleTypeInterface:$get_global, + TransformHandleTypeInterface:$global); + + let assemblyFormat = [{ + $alloca `in` $module attr-dict `:` functional-type(operands, results) + }]; +} def MemRefMultiBufferOp : Op getUniqueSymbol(llvm::StringRef prefix, + ModuleOp module) { + llvm::SmallString<64> candidateNameStorage; + StringRef candidateName(prefix); + int uniqueNumber = 0; + while (true) { + if (!module.lookupSymbol(candidateName)) { + break; + } + candidateNameStorage.clear(); + candidateName = (prefix + Twine("_") + Twine(uniqueNumber)) + .toStringRef(candidateNameStorage); + uniqueNumber++; + } + return candidateName; +} +} // namespace + +DiagnosedSilenceableFailure +transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto allocaOps = state.getPayloadOps(getAlloca()); + + SmallVector globalOps; + SmallVector getGlobalOps; + + // Get `builtin.module`. + auto moduleOps = state.getPayloadOps(getModule()); + if (!llvm::hasSingleElement(moduleOps)) { + return emitDefiniteFailure() + << Twine("expected exactly one 'module' payload, but found ") + + std::to_string(llvm::range_size(moduleOps)); + } + ModuleOp module = cast(*moduleOps.begin()); + + // Transform `memref.alloca`s. + for (auto *op : allocaOps) { + auto alloca = cast(op); + MLIRContext *ctx = rewriter.getContext(); + Location loc = alloca->getLoc(); + + memref::GlobalOp globalOp; + { + // Insert a `memref.global` at the beginning of the module. + if (module != alloca->getParentOfType()) { + return emitDefiniteFailure() + << "expected 'alloca' payload to be inside 'module' payload"; + } + IRRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&module.getBodyRegion().front()); + Type resultType = alloca.getResult().getType(); + llvm::SmallString<64> symName = getUniqueSymbol("alloca", module); + // XXX: Add a better builder for this. + globalOp = rewriter.create( + loc, StringAttr::get(ctx, symName), StringAttr::get(ctx, "private"), + TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{}); + } + + // Replace the `memref.alloca` with a `memref.get_global` accessing the + // global symbol inserted above. + rewriter.setInsertionPoint(alloca); + auto getGlobalOp = rewriter.replaceOpWithNewOp( + alloca, globalOp.getType(), globalOp.getName()); + + globalOps.push_back(globalOp); + getGlobalOps.push_back(getGlobalOp); + } + + // Assemble results. + results.set(getGlobal().cast(), globalOps); + results.set(getGetGlobal().cast(), getGlobalOps); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::MemRefAllocaToGlobalOp::getEffects( + SmallVectorImpl &effects) { + onlyReadsHandle(getModule(), effects); + producesHandle(getGlobal(), effects); + producesHandle(getGetGlobal(), effects); + consumesHandle(getAlloca(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // MemRefMultiBufferOp //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py index 4afe8e7b887f6..56dcfbe5655e9 100644 --- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py @@ -11,6 +11,64 @@ from typing import Optional, overload, Union +class MemRefAllocaToGlobalOp: + """Specialization for MemRefAllocaToGlobalOp class.""" + + @overload + def __init__( + self, + get_global_type: Type, + global_type: Type, + module: Union[Operation, OpView, Value], + alloca: Union[Operation, OpView, Value], + *, + loc=None, + ip=None + ): + ... + + @overload + def __init__( + self, + module: Union[Operation, OpView, Value], + alloca: Union[Operation, OpView, Value], + *, + loc=None, + ip=None + ): + ... + + def __init__( + self, + get_global_type_or_module: Union[Operation, OpView, Type, Value], + global_type_or_alloca: Union[Operation, OpView, Type, Value], + module_or_none: Optional[Union[Operation, OpView, Value]] = None, + alloca_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None + ): + if isinstance(get_global_type_or_module, Type): + get_global_type = get_global_type_or_module + global_type = global_type_or_alloca + module = module_or_none + alloca = alloca_or_none + else: + get_global_type = transform.AnyOpType.get() + global_type = transform.AnyOpType.get() + module = get_global_type_or_module + alloca = global_type_or_alloca + + super().__init__( + get_global_type, + global_type, + module, + alloca, + loc=loc, + ip=ip, + ) + + class MemRefMultiBufferOp: """Specialization for MemRefMultiBufferOp class.""" diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir index b19db447af1c2..aeeb2a6b0abed 100644 --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -1,5 +1,44 @@ // RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s +// CHECK-DAG: memref.global "private" @[[ALLOC0:alloc.*]] : memref<2x32xf32> +// CHECK-DAG: memref.global "private" @[[ALLOC1:alloc.*]] : memref<2x32xf32> + +// CHECK: func.func @func( +func.func @func(%arg0: f32) { + %c3 = arith.constant 3 : index + %c1 = arith.constant 1 : index + // CHECK: scf.forall + scf.forall (%arg1, %arg2) in (%c3, %c1) { + // CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref<2x32xf32> + // CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref<2x32xf32> + // CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref<2x32xf32> + // CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref<2x32xf32> + %alloca = memref.alloca() : memref<2x32xf32> + %alloca_0 = memref.alloca() : memref<2x32xf32> + memref.store %arg0, %alloca[%arg1, %arg2] : memref<2x32xf32> + memref.store %arg0, %alloca_0[%arg1, %arg2] : memref<2x32xf32> + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %alloca = transform.structured.match ops{["memref.alloca"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %module = transform.structured.match ops{["builtin.module"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %alloca_typed = transform.cast %alloca + : !transform.any_op to !transform.op<"memref.alloca"> + %module_typed = transform.cast %module + : !transform.any_op to !transform.op<"builtin.module"> + %get_global, %global = + transform.memref.alloca_to_global %alloca_typed in %module_typed + : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">) + -> (!transform.any_op, !transform.any_op) +} + +// ----- + // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)> // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py index f89005cb2f86d..8278019bbab3b 100644 --- a/mlir/test/python/dialects/transform_memref_ext.py +++ b/mlir/test/python/dialects/transform_memref_ext.py @@ -16,6 +16,54 @@ def run(f): return f +@run +def testMemRefAllocaToAllocOpCompact(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("memref.alloc"), + ) + with InsertionPoint(sequence.body): + module = transform.CastOp( + transform.OperationType.get("builtin.module"), sequence.bodyTarget + ) + alloca = transform.CastOp( + transform.OperationType.get("memref.alloca"), sequence.bodyTarget + ) + memref.MemRefAllocaToGlobalOp(module, alloca) + transform.YieldOp() + # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact + # CHECK: = transform.memref.alloca_to_global + # CHECK-SAME: (!transform.op<"builtin.module">, !transform.op<"memref.alloca">) + # CHECK-SAME: -> (!transform.any_op, !transform.any_op) + + +@run +def testMemRefAllocaToAllocOpTyped(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("memref.alloc"), + ) + with InsertionPoint(sequence.body): + module = transform.CastOp( + transform.OperationType.get("builtin.module"), sequence.bodyTarget + ) + alloca = transform.CastOp( + transform.OperationType.get("memref.alloca"), sequence.bodyTarget + ) + memref.MemRefAllocaToGlobalOp( + transform.OperationType.get("memref.get_global"), + transform.OperationType.get("memref.global"), + module, + alloca, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped + # CHECK: = transform.memref.alloca_to_global + # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">) + + @run def testMemRefMultiBufferOpCompact(): sequence = transform.SequenceOp( From b8f439e2fe4ba708beb906c9cac76f50845f94a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 21 Sep 2023 09:21:58 +0000 Subject: [PATCH 2/4] Address comments from @ftynse's review. In particular: * Accept any op type with `SymbolTable` trait as containing op rather than only `builtin.module` and rename op argument accordingly. * Use `SymbolTable::insert` to unique the name of the globals rather than some hand-rolled function. * Use more sane semantics in Python mix-in test. --- .../MemRef/TransformOps/MemRefTransformOps.td | 22 ++++--- .../TransformOps/MemRefTransformOps.cpp | 63 ++++++++----------- mlir/test/Dialect/MemRef/transform-ops.mlir | 19 +++--- .../python/dialects/transform_memref_ext.py | 20 ++---- 4 files changed, 54 insertions(+), 70 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td index 6a78784d74dd5..af2401a80b898 100644 --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -153,9 +153,10 @@ def MemRefAllocaToGlobalOp : DeclareOpInterfaceMethods]> { let description = [{ Inserts a new `memref.global` for each provided `memref.alloca` into the - provided module and replaces it with a `memref.get_global`. This is useful, - for example, for allocations that should reside in the shared memory of - a GPU, which have to be declared as globals. + provided symbol table (e.g., a `builtin.module`) and replaces it with a + `memref.get_global`. This is useful, for example, for allocations that + should reside in the shared memory of a GPU, which have to be declared as + globals. #### Example @@ -194,19 +195,20 @@ def MemRefAllocaToGlobalOp : #### Return modes - Emits a definite failure if not exactly one `module` payload op was provided - or any of the `alloca` payload ops is not inside that module, and succeeds - otherwise. The returned handles refer to the `memref.get_global` and - `memref.global` ops that were inserted by the transformation. + Emits a definite failure if not exactly one symbol table payload op was + provided or any of the `alloca` payload ops is not inside that symbol table + op, and succeeds otherwise. The returned handles refer to the + `memref.get_global` and `memref.global` ops that were inserted by the + transformation. }]; - let arguments = (ins Transform_ConcreteOpType<"builtin.module">:$module, + let arguments = (ins TransformHandleTypeInterface:$symbolTable, Transform_MemRefAllocaOp:$alloca); - let results = (outs TransformHandleTypeInterface:$get_global, + let results = (outs TransformHandleTypeInterface:$getGlobal, TransformHandleTypeInterface:$global); let assemblyFormat = [{ - $alloca `in` $module attr-dict `:` functional-type(operands, results) + $alloca `in` $symbolTable attr-dict `:` functional-type(operands, results) }]; } diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index 7467359da83c3..e5e19b4edbc5a 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -130,25 +130,6 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp:: // AllocaToGlobalOp //===----------------------------------------------------------------------===// -namespace { -static llvm::SmallString<64> getUniqueSymbol(llvm::StringRef prefix, - ModuleOp module) { - llvm::SmallString<64> candidateNameStorage; - StringRef candidateName(prefix); - int uniqueNumber = 0; - while (true) { - if (!module.lookupSymbol(candidateName)) { - break; - } - candidateNameStorage.clear(); - candidateName = (prefix + Twine("_") + Twine(uniqueNumber)) - .toStringRef(candidateNameStorage); - uniqueNumber++; - } - return candidateName; -} -} // namespace - DiagnosedSilenceableFailure transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, @@ -158,14 +139,25 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, SmallVector globalOps; SmallVector getGlobalOps; - // Get `builtin.module`. - auto moduleOps = state.getPayloadOps(getModule()); - if (!llvm::hasSingleElement(moduleOps)) { + // Get containing symbol table op. + auto symbolTableOps = state.getPayloadOps(getSymbolTable()); + if (!llvm::hasSingleElement(symbolTableOps)) { return emitDefiniteFailure() - << Twine("expected exactly one 'module' payload, but found ") + - std::to_string(llvm::range_size(moduleOps)); + << Twine("expected exactly one 'symbolTable' payload, but found ") + + std::to_string(llvm::range_size(symbolTableOps)); + } + Operation *symbolTableOp = *symbolTableOps.begin(); + if (!symbolTableOp->hasTrait()) { + return emitDefiniteFailure() << Twine( + "expected 'symbolTable' payload to have 'SymbolTable' trait"); + } + SymbolTable symbolTable(symbolTableOp); + + { + size_t numAllocaOps = llvm::range_size(allocaOps); + globalOps.reserve(numAllocaOps); + getGlobalOps.reserve(numAllocaOps); } - ModuleOp module = cast(*moduleOps.begin()); // Transform `memref.alloca`s. for (auto *op : allocaOps) { @@ -175,19 +167,18 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, memref::GlobalOp globalOp; { - // Insert a `memref.global` at the beginning of the module. - if (module != alloca->getParentOfType()) { - return emitDefiniteFailure() - << "expected 'alloca' payload to be inside 'module' payload"; + // Insert a `memref.global` into the symbol table. + if (symbolTable.getOp() != SymbolTable::getNearestSymbolTable(op)) { + return emitDefiniteFailure() << "expected 'alloca' payload to be " + "inside 'symbolTable' payload"; } - IRRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&module.getBodyRegion().front()); Type resultType = alloca.getResult().getType(); - llvm::SmallString<64> symName = getUniqueSymbol("alloca", module); - // XXX: Add a better builder for this. - globalOp = rewriter.create( - loc, StringAttr::get(ctx, symName), StringAttr::get(ctx, "private"), + // TODO: Add a better builder for this. + OpBuilder builder(rewriter.getContext()); + globalOp = builder.create( + loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"), TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{}); + symbolTable.insert(globalOp); } // Replace the `memref.alloca` with a `memref.get_global` accessing the @@ -209,7 +200,7 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, void transform::MemRefAllocaToGlobalOp::getEffects( SmallVectorImpl &effects) { - onlyReadsHandle(getModule(), effects); + onlyReadsHandle(getSymbolTable(), effects); producesHandle(getGlobal(), effects); producesHandle(getGetGlobal(), effects); consumesHandle(getAlloca(), effects); diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir index aeeb2a6b0abed..e22e3d62190c4 100644 --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -3,20 +3,19 @@ // CHECK-DAG: memref.global "private" @[[ALLOC0:alloc.*]] : memref<2x32xf32> // CHECK-DAG: memref.global "private" @[[ALLOC1:alloc.*]] : memref<2x32xf32> -// CHECK: func.func @func( -func.func @func(%arg0: f32) { - %c3 = arith.constant 3 : index - %c1 = arith.constant 1 : index - // CHECK: scf.forall - scf.forall (%arg1, %arg2) in (%c3, %c1) { +// CHECK-DAG: func.func @func(%[[LB:.*]]: index, %[[UB:.*]]: index) +func.func @func(%lb: index, %ub: index) { + // CHECK-DAG: scf.forall (%[[ARG0:.*]], %[[ARG1:.*]]) in (%[[LB]], %[[UB]]) + scf.forall (%arg0, %arg1) in (%lb, %ub) { // CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref<2x32xf32> // CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref<2x32xf32> // CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref<2x32xf32> // CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref<2x32xf32> - %alloca = memref.alloca() : memref<2x32xf32> - %alloca_0 = memref.alloca() : memref<2x32xf32> - memref.store %arg0, %alloca[%arg1, %arg2] : memref<2x32xf32> - memref.store %arg0, %alloca_0[%arg1, %arg2] : memref<2x32xf32> + %cst = arith.constant 0.0 : f32 + %mr0 = memref.alloca() : memref<2x32xf32> + %mr1 = memref.alloca() : memref<2x32xf32> + memref.store %cst, %mr0[%arg0, %arg1] : memref<2x32xf32> + memref.store %cst, %mr1[%arg0, %arg1] : memref<2x32xf32> } return } diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py index 8278019bbab3b..6f622cbde7085 100644 --- a/mlir/test/python/dialects/transform_memref_ext.py +++ b/mlir/test/python/dialects/transform_memref_ext.py @@ -21,15 +21,11 @@ def testMemRefAllocaToAllocOpCompact(): sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], - transform.OperationType.get("memref.alloc"), + transform.OperationType.get("builtin.module"), + [transform.OperationType.get("memref.alloca")], ) with InsertionPoint(sequence.body): - module = transform.CastOp( - transform.OperationType.get("builtin.module"), sequence.bodyTarget - ) - alloca = transform.CastOp( - transform.OperationType.get("memref.alloca"), sequence.bodyTarget - ) + module, alloca = sequence.body.arguments memref.MemRefAllocaToGlobalOp(module, alloca) transform.YieldOp() # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact @@ -43,15 +39,11 @@ def testMemRefAllocaToAllocOpTyped(): sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], - transform.OperationType.get("memref.alloc"), + transform.OperationType.get("builtin.module"), + [transform.OperationType.get("memref.alloca")], ) with InsertionPoint(sequence.body): - module = transform.CastOp( - transform.OperationType.get("builtin.module"), sequence.bodyTarget - ) - alloca = transform.CastOp( - transform.OperationType.get("memref.alloca"), sequence.bodyTarget - ) + module, alloca = sequence.body.arguments memref.MemRefAllocaToGlobalOp( transform.OperationType.get("memref.get_global"), transform.OperationType.get("memref.global"), From a91c93ddcf4535405644cdc477229a5f982892ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 21 Sep 2023 13:37:07 +0000 Subject: [PATCH 3/4] Remove symbolTable op arg and simplify tests. --- .../MemRef/TransformOps/MemRefTransformOps.td | 18 ++++------- .../TransformOps/MemRefTransformOps.cpp | 32 ++++--------------- mlir/test/Dialect/MemRef/transform-ops.mlir | 13 ++------ 3 files changed, 16 insertions(+), 47 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td index af2401a80b898..d7bd8410e360a 100644 --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -153,7 +153,7 @@ def MemRefAllocaToGlobalOp : DeclareOpInterfaceMethods]> { let description = [{ Inserts a new `memref.global` for each provided `memref.alloca` into the - provided symbol table (e.g., a `builtin.module`) and replaces it with a + nearest symbol table (e.g., a `builtin.module`) and replaces it with a `memref.get_global`. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals. @@ -164,8 +164,8 @@ def MemRefAllocaToGlobalOp : ```mlir %get_global, %global = - transform.memref.alloca_to_global %alloca in %module - : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">) + transform.memref.alloca_to_global %alloca + : (!transform.op<"memref.alloca">) -> (!transform.any_op, !transform.any_op) ``` @@ -195,20 +195,16 @@ def MemRefAllocaToGlobalOp : #### Return modes - Emits a definite failure if not exactly one symbol table payload op was - provided or any of the `alloca` payload ops is not inside that symbol table - op, and succeeds otherwise. The returned handles refer to the - `memref.get_global` and `memref.global` ops that were inserted by the - transformation. + Succeeds always. The returned handles refer to the `memref.get_global` and + `memref.global` ops that were inserted by the transformation. }]; - let arguments = (ins TransformHandleTypeInterface:$symbolTable, - Transform_MemRefAllocaOp:$alloca); + let arguments = (ins Transform_MemRefAllocaOp:$alloca); let results = (outs TransformHandleTypeInterface:$getGlobal, TransformHandleTypeInterface:$global); let assemblyFormat = [{ - $alloca `in` $symbolTable attr-dict `:` functional-type(operands, results) + $alloca attr-dict `:` functional-type(operands, results) }]; } diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index e5e19b4edbc5a..eed29efcaaada 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -139,26 +139,6 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, SmallVector globalOps; SmallVector getGlobalOps; - // Get containing symbol table op. - auto symbolTableOps = state.getPayloadOps(getSymbolTable()); - if (!llvm::hasSingleElement(symbolTableOps)) { - return emitDefiniteFailure() - << Twine("expected exactly one 'symbolTable' payload, but found ") + - std::to_string(llvm::range_size(symbolTableOps)); - } - Operation *symbolTableOp = *symbolTableOps.begin(); - if (!symbolTableOp->hasTrait()) { - return emitDefiniteFailure() << Twine( - "expected 'symbolTable' payload to have 'SymbolTable' trait"); - } - SymbolTable symbolTable(symbolTableOp); - - { - size_t numAllocaOps = llvm::range_size(allocaOps); - globalOps.reserve(numAllocaOps); - getGlobalOps.reserve(numAllocaOps); - } - // Transform `memref.alloca`s. for (auto *op : allocaOps) { auto alloca = cast(op); @@ -167,14 +147,15 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, memref::GlobalOp globalOp; { + // Find nearest symbol table. + Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op); + assert(symbolTableOp && "expected alloca payload to be in symbol table"); + SymbolTable symbolTable(symbolTableOp); + // Insert a `memref.global` into the symbol table. - if (symbolTable.getOp() != SymbolTable::getNearestSymbolTable(op)) { - return emitDefiniteFailure() << "expected 'alloca' payload to be " - "inside 'symbolTable' payload"; - } Type resultType = alloca.getResult().getType(); - // TODO: Add a better builder for this. OpBuilder builder(rewriter.getContext()); + // TODO: Add a better builder for this. globalOp = builder.create( loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"), TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{}); @@ -200,7 +181,6 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, void transform::MemRefAllocaToGlobalOp::getEffects( SmallVectorImpl &effects) { - onlyReadsHandle(getSymbolTable(), effects); producesHandle(getGlobal(), effects); producesHandle(getGetGlobal(), effects); consumesHandle(getAlloca(), effects); diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir index e22e3d62190c4..68fea1f840295 100644 --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -23,16 +23,9 @@ func.func @func(%lb: index, %ub: index) { transform.sequence failures(propagate) { ^bb1(%arg0: !transform.any_op): %alloca = transform.structured.match ops{["memref.alloca"]} in %arg0 - : (!transform.any_op) -> !transform.any_op - %module = transform.structured.match ops{["builtin.module"]} in %arg0 - : (!transform.any_op) -> !transform.any_op - %alloca_typed = transform.cast %alloca - : !transform.any_op to !transform.op<"memref.alloca"> - %module_typed = transform.cast %module - : !transform.any_op to !transform.op<"builtin.module"> - %get_global, %global = - transform.memref.alloca_to_global %alloca_typed in %module_typed - : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">) + : (!transform.any_op) -> !transform.op<"memref.alloca"> + %get_global, %global = transform.memref.alloca_to_global %alloca + : (!transform.op<"memref.alloca">) -> (!transform.any_op, !transform.any_op) } From 05598559cf8e2c4a83ab77f74d84b9de3e9edfa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 21 Sep 2023 14:40:07 +0000 Subject: [PATCH 4/4] Remove symbolTable op arg in Python mix-in. --- .../dialects/_memref_transform_ops_ext.py | 26 +++++-------------- .../python/dialects/transform_memref_ext.py | 15 ++++------- 2 files changed, 12 insertions(+), 29 deletions(-) diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py index 56dcfbe5655e9..1cc00bdcbf381 100644 --- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py @@ -19,7 +19,6 @@ def __init__( self, get_global_type: Type, global_type: Type, - module: Union[Operation, OpView, Value], alloca: Union[Operation, OpView, Value], *, loc=None, @@ -28,41 +27,30 @@ def __init__( ... @overload - def __init__( - self, - module: Union[Operation, OpView, Value], - alloca: Union[Operation, OpView, Value], - *, - loc=None, - ip=None - ): + def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None): ... def __init__( self, - get_global_type_or_module: Union[Operation, OpView, Type, Value], - global_type_or_alloca: Union[Operation, OpView, Type, Value], - module_or_none: Optional[Union[Operation, OpView, Value]] = None, + get_global_type_or_alloca: Union[Operation, OpView, Type, Value], + global_type_or_none: Optional[Type] = None, alloca_or_none: Optional[Union[Operation, OpView, Value]] = None, *, loc=None, ip=None ): - if isinstance(get_global_type_or_module, Type): - get_global_type = get_global_type_or_module - global_type = global_type_or_alloca - module = module_or_none + if isinstance(get_global_type_or_alloca, Type): + get_global_type = get_global_type_or_alloca + global_type = global_type_or_none alloca = alloca_or_none else: get_global_type = transform.AnyOpType.get() global_type = transform.AnyOpType.get() - module = get_global_type_or_module - alloca = global_type_or_alloca + alloca = get_global_type_or_alloca super().__init__( get_global_type, global_type, - module, alloca, loc=loc, ip=ip, diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py index 6f622cbde7085..e7d871c9eac8c 100644 --- a/mlir/test/python/dialects/transform_memref_ext.py +++ b/mlir/test/python/dialects/transform_memref_ext.py @@ -21,16 +21,14 @@ def testMemRefAllocaToAllocOpCompact(): sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], - transform.OperationType.get("builtin.module"), - [transform.OperationType.get("memref.alloca")], + transform.OperationType.get("memref.alloca"), ) with InsertionPoint(sequence.body): - module, alloca = sequence.body.arguments - memref.MemRefAllocaToGlobalOp(module, alloca) + memref.MemRefAllocaToGlobalOp(sequence.bodyTarget) transform.YieldOp() # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact # CHECK: = transform.memref.alloca_to_global - # CHECK-SAME: (!transform.op<"builtin.module">, !transform.op<"memref.alloca">) + # CHECK-SAME: (!transform.op<"memref.alloca">) # CHECK-SAME: -> (!transform.any_op, !transform.any_op) @@ -39,16 +37,13 @@ def testMemRefAllocaToAllocOpTyped(): sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], - transform.OperationType.get("builtin.module"), - [transform.OperationType.get("memref.alloca")], + transform.OperationType.get("memref.alloca"), ) with InsertionPoint(sequence.body): - module, alloca = sequence.body.arguments memref.MemRefAllocaToGlobalOp( transform.OperationType.get("memref.get_global"), transform.OperationType.get("memref.global"), - module, - alloca, + sequence.bodyTarget, ) transform.YieldOp() # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped