From aef7a6ecc03be81cd4f00b783c79800798da1f56 Mon Sep 17 00:00:00 2001
From: Tim Gymnich
Date: Fri, 15 Aug 2025 13:19:30 +0000
Subject: [PATCH] add rocdl.permlane16.swap and rocdl.permlane32.swap

---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 51 ++++++++++++++++++++
 mlir/test/Dialect/LLVMIR/rocdl.mlir          | 16 +++++++
 mlir/test/Target/LLVMIR/rocdl.mlir           | 14 ++++++
 3 files changed, 81 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index d6761f4da21ff..9fa3ec1fc4b21 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -780,6 +780,57 @@ def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
   }];
 }
 
+// LLVM struct type with exactly two fields whose types satisfy `elem0` and
+// `elem1`; buildable as a literal struct. The field count is checked before
+// the fields are indexed so the element predicates never read out of bounds.
+class ROCDL_ConcretePair<Type elem0, Type elem1> :
+  Type<And<[
+    LLVM_AnyStruct.predicate,
+    CPred<"::llvm::cast<::mlir::LLVM::LLVMStructType>($_self).getBody().size() == 2">,
+    SubstLeaves<
+      "$_self",
+      "::llvm::cast<::mlir::LLVM::LLVMStructType>($_self).getBody()[0]",
+      elem0.predicate>,
+    SubstLeaves<
+      "$_self",
+      "::llvm::cast<::mlir::LLVM::LLVMStructType>($_self).getBody()[1]",
+      elem1.predicate>
+  ]>,
+  "LLVM dialect-compatible struct of " # elem0.summary # " and " # elem1.summary,
+  "::mlir::LLVM::LLVMStructType">,
+  BuildableType<"::mlir::LLVM::LLVMStructType::getLiteral($_builder.getContext(), "
+    "{" # elem0.builderCall # ", " # elem1.builderCall # "})">;
+
+// Permlane16 swap intrinsic operation
+def ROCDL_Permlane16SwapOp : ROCDL_IntrOp<"permlane16.swap", [], [],
+  [], 1, 0, 0, 0,
+  [2, 3], ["fi", "boundControl"]>,
+  Arguments<(ins I32:$old, I32:$src, I1Attr:$fi, I1Attr:$boundControl)> {
+  let results = (outs ROCDL_ConcretePair<I32, I32>:$res);
+  let assemblyFormat = [{
+    attr-dict $old `,` $src `,` $fi `,` $boundControl `:` `(` type($old) `,` type($src) `)` `->` type($res)
+  }];
+  let description = [{
+    Swaps odd rows of `$old` with even rows of `$src` (one row is 16 lanes) and
+    returns the swapped values of `$old` and `$src` as a struct, in that order.
+  }];
+}
+
+// Permlane32 swap intrinsic operation
+def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
+  [], 1, 0, 0, 0,
+  [2, 3], ["fi", "boundControl"]>,
+  Arguments<(ins I32:$old, I32:$src, I1Attr:$fi, I1Attr:$boundControl)> {
+  let results = (outs ROCDL_ConcretePair<I32, I32>:$res);
+  let assemblyFormat = [{
+    attr-dict $old `,` $src `,` $fi `,` $boundControl `:` `(` type($old) `,` type($src) `)` `->` type($res)
+  }];
+  let description = [{
+    Swaps rows 2 and 3 of `$old` with rows 0 and 1 of `$src` (one row is 16
+    lanes) and returns the swapped `$old` and `$src` as a struct, in order.
+  }];
+}
+
 class ROCDL_ConcreteVector<Type elem, int length> :
   FixedVectorOfLengthAndType<[length], [elem]>,
   BuildableType<
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index db5271c57f573..782ef4e154440 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -1009,6 +1009,22 @@ llvm.func @rocdl.permlanex16(%src : f32) -> f32 {
 
 // -----
 
+llvm.func @rocdl.permlane16.swap(%src : i32) -> !llvm.struct<(i32, i32)> {
+  // CHECK-LABEL: rocdl.permlane16.swap
+  // CHECK: rocdl.permlane16.swap %{{.*}} %{{.*}}
+  %res = rocdl.permlane16.swap %src, %src, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
+  llvm.return %res : !llvm.struct<(i32, i32)>
+}
+
+llvm.func @rocdl.permlane32.swap(%src : i32) -> !llvm.struct<(i32, i32)> {
+  // CHECK-LABEL: rocdl.permlane32.swap
+  // CHECK: rocdl.permlane32.swap %{{.*}} %{{.*}}
+  %res = rocdl.permlane32.swap %src, %src, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
+  llvm.return %res : !llvm.struct<(i32, i32)>
+}
+
+// -----
+
 // expected-error@below {{attribute attached to unexpected op}}
 func.func private @expected_llvm_func() attributes { rocdl.kernel }
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index ac334eadb3c23..a464358250c38 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -941,6 +941,20 @@ llvm.func @rocdl.permlanex16(%src0 : f32, %src1 : i32, %src2 : vector<2 x f32>,
   llvm.return %ret0 : f32
 }
 
+llvm.func @rocdl.permlane16.swap(%src : i32) -> !llvm.struct<(i32, i32)> {
+  // CHECK-LABEL: rocdl.permlane16.swap
+  // CHECK: call { i32, i32 } @llvm.amdgcn.permlane16.swap(i32 %{{.*}}, i32 %{{.*}}, i1 false, i1 true)
+  %ret = rocdl.permlane16.swap %src, %src, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
+  llvm.return %ret : !llvm.struct<(i32, i32)>
+}
+
+llvm.func @rocdl.permlane32.swap(%src : i32) -> !llvm.struct<(i32, i32)> {
+  // CHECK-LABEL: rocdl.permlane32.swap
+  // CHECK: call { i32, i32 } @llvm.amdgcn.permlane32.swap(i32 %{{.*}}, i32 %{{.*}}, i1 false, i1 true)
+  %ret = rocdl.permlane32.swap %src, %src, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
+  llvm.return %ret : !llvm.struct<(i32, i32)>
+}
+
 llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
   // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
   %r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>