diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index bd5ea9fd83781..81e25f7537cb0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -110,23 +110,34 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     Variadic<Index>: $offsets,
     Variadic<Index>: $shape,
     Variadic<Index>: $strides,
-    DenseI64ArrayAttr: $const_offsets,
+    OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
     OptionalAttr<DenseI64ArrayAttr>: $const_shape,
     OptionalAttr<DenseI64ArrayAttr>: $const_strides
   );
-  let results = (outs XeGPU_TensorDesc: $TensorDesc);
 
   let assemblyFormat = [{
     $source ``
-    custom<DynamicIndexList>($offsets, $const_offsets)
-    (`,` custom<DynamicIndexList>($shape, $const_shape)^
-    `,` custom<DynamicIndexList>($strides, $const_strides))?
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    (`,` `shape` `:` custom<DynamicIndexList>($shape, $const_shape)^
+    `,` `strides``:` custom<DynamicIndexList>($strides, $const_strides))?
     attr-dict `:` type($source) `->` qualified(type($TensorDesc))
   }];
 
+  let results = (outs XeGPU_TensorDesc: $TensorDesc);
+
   let hasVerifier = 1;
 
   let builders = [
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
+
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
+                   "llvm::ArrayRef<OpFoldResult>": $shape,
+                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+                   "llvm::ArrayRef<OpFoldResult>": $shape,
+                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets)>,
 
@@ -163,7 +174,17 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     }
 
     ArrayRef<int64_t> getStaticOffsets(){
-      return getConstOffsets();
+      auto attr = getConstOffsetsAttr();
+
+      if (attr)
+        return attr;
+
+      int64_t rank = getMixedSizes().size();
+
+      setConstOffsets(llvm::SmallVector<int64_t>(rank, 0));
+
+      attr = getConstOffsetsAttr();
+      return attr;
     }
 
     /// wrapper for matching with OffsetSizeAndStrideOpInterface
@@ -172,10 +193,16 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     /// and `const_shape` will be used to represent the shape of
    /// source operand. They overide static shape from source memref type.
     ArrayRef<int64_t> getStaticSizes() {
+      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects a valid return value and performs checks
+      static llvm::SmallVector<int64_t> emptyShape;
+
       auto attr = getConstShapeAttr();
-      if (llvm::isa<IntegerType>(getSourceType()) || attr)
+      if (attr)
         return attr;
+
+      if (llvm::isa<IntegerType>(getSourceType()))
+        return emptyShape;
+
       auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
       assert(memrefType && "Incorrect use of getStaticSizes");
       return memrefType.getShape();
@@ -187,9 +214,15 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     /// and `const_strides` will be used to represent the strides of
     /// source operand. They overide static strides from source memref type.
     ArrayRef<int64_t> getStaticStrides() {
+      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects a valid return value and performs checks
+      static llvm::SmallVector<int64_t> emptyStrides;
+
       auto attr = getConstStridesAttr();
-      if (llvm::isa<IntegerType>(getSourceType()) || attr)
+      if (attr)
         return attr;
+
+      if (llvm::isa<IntegerType>(getSourceType()))
+        return emptyStrides;
 
       auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
       assert(memrefType && "Incorrect use of getStaticStrides");
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ef7cd1424e7a4..78cbf884a1911 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 #include "llvm/Support/Debug.h"
 
@@ -112,6 +113,68 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
+
+void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
+                           Type tdesc, TypedValue<MemRefType> source) {
+  [[maybe_unused]] auto ty = source.getType();
+  assert(ty.hasStaticShape() && "expecting a memref with static shape");
+
+  build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
+        ValueRange({}) /* empty dynamic shape */,
+        ValueRange({}) /* empty dynamic strides */,
+        DenseI64ArrayAttr({}) /* const offsets */,
+        DenseI64ArrayAttr({}) /* empty const shape*/,
+        DenseI64ArrayAttr({}) /* empty const strides*/);
+}
+
+void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
+                           Type tdesc, TypedValue<MemRefType> source,
+                           llvm::ArrayRef<OpFoldResult> shape,
+                           llvm::ArrayRef<OpFoldResult> strides) {
+  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
+         "Shape and strides must be present and of equal size for ui64 "
+         "initialization.");
+
+  llvm::SmallVector<int64_t> staticShape;
+  llvm::SmallVector<int64_t> staticStrides;
+  llvm::SmallVector<Value> dynamicShape;
+  llvm::SmallVector<Value> dynamicStrides;
+
+  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+
+  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
+  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
+
+  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
+        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
+        staticStridesAttr);
+}
+
+void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
+                           Type tdesc, TypedValue<IntegerType> source,
+                           llvm::ArrayRef<OpFoldResult> shape,
+                           llvm::ArrayRef<OpFoldResult> strides) {
+  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
+         "Shape and strides must be present and of equal size for ui64 "
+         "initialization.");
+
+  llvm::SmallVector<int64_t> staticShape;
+  llvm::SmallVector<int64_t> staticStrides;
+  llvm::SmallVector<Value> dynamicShape;
+  llvm::SmallVector<Value> dynamicStrides;
+
+  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+
+  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
+  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
+
+  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
+        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
+        staticStridesAttr);
+}
+
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                            Type tdesc, TypedValue<MemRefType> source,
                            llvm::ArrayRef<OpFoldResult> offsets) {
@@ -125,8 +188,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
         ValueRange({}) /* empty dynamic shape */,
         ValueRange({}) /* empty dynamic strides */,
-        staticOffsets /* const offsets */, {} /* empty const shape*/,
-        {} /* empty const strides*/);
+        builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
+        {} /* empty const shape*/, {} /* empty const strides*/);
 }
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
@@ -197,6 +260,13 @@ LogicalResult CreateNdDescOp::verify() {
     invalidElemTy |= memrefTy.getElementType() != getElementType();
   }
 
+  if (llvm::isa<IntegerType>(getSourceType())) {
+    // strides and shape must be present for integer source.
+    if (getMixedStrides().empty() || getMixedSizes().empty())
+      return emitOpError("Expecting strides and shape to be present for "
+                         "integer source.");
+  }
+
   // mismatches among shape, strides, and offsets are
   // already handeled by OffsetSizeAndStrideOpInterface.
   // So they are not check here.
@@ -221,6 +291,53 @@ LogicalResult CreateNdDescOp::verify() {
   return success();
 }
 
+ParseResult parseOptionalDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+
+  SmallVector<int64_t> integerVals;
+  auto parseIntegerOrValue = [&]() {
+    OpAsmParser::UnresolvedOperand operand;
+    auto res = parser.parseOptionalOperand(operand);
+
+    if (res.has_value() && succeeded(res.value())) {
+      values.push_back(operand);
+      integerVals.push_back(ShapedType::kDynamic);
+      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+        return failure();
+    } else {
+      int64_t integer;
+      if (failed(parser.parseInteger(integer)))
+        return failure();
+      integerVals.push_back(integer);
+    }
+    return success();
+  };
+
+  // If the optional values are given there must be left bracket
+  if (parser.parseOptionalLSquare().succeeded()) {
+    if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
+        parser.parseRSquare())
+      return parser.emitError(parser.getNameLoc())
+             << "expected a list of SSA values or integers";
+    integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
+    return success();
+  }
+
+  return success();
+}
+
+void printOptionalDynamicIndexList(
+    OpAsmPrinter &printer, Operation *op, OperandRange values,
+    ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+
+  return printDynamicIndexList(printer, op, values, integers,
+                               /*scalableFlags=*/{}, valueTypes, delimiter);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 4af7061a4f8a3..58719e75b1bde 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -54,7 +54,7 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
 // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK: return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index d68a02b54e967..0d3da815529e3 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -56,7 +56,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
 // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index c2f760b29afc4..05b41a8233e8c 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -96,7 +96,7 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
 // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK: return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 8de6c2283b37c..2bfee03892d10 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -60,7 +60,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
 // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 83a98ab0622b7..eb564d55bfd51 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 // -----
-func.func @create_nd_tdesc_vc_1(%src: memref<24xf32>) {
+func.func @create_nd_tdesc_1(%src: memref<24xf32>) {
   // expected-error@+1 {{Expecting the TensorDesc rank is not greater than the ranks of shape, strides, offsets or the memref source}}
   %1 = xegpu.create_nd_tdesc %src[0] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32>
   return
@@ -9,47 +9,62 @@ func.func @create_nd_tdesc_vc_1(%src: memref<24xf32>) {
 
 // -----
-func.func @create_nd_tdesc_vc_2(%src: memref<24x32xf32>) {
+func.func @create_nd_tdesc_2(%src: memref<24x32xf32>) {
   // expected-error@+1 {{TensorDesc should have the same element type with the source if it is a memref}}
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf16>
   return
 }
 
 // -----
-func.func @create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) {
+func.func @create_nd_tdesc_3(%src: memref<2x24x32xf32, 3>) {
   // expected-error@+1 {{SLM is only supported for 1D block tensor}}
   %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
   return
 }
 
 // -----
-func.func @create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) {
+func.func @create_nd_tdesc_4(%src: memref<2x24x32xf32, 3>) {
   // expected-error@+1 {{Memory space mismatch}}
   %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<16xf32>
   return
 }
 
 // -----
-func.func @create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
+func.func @create_nd_tdesc_5(%src: memref<128x128xf32>) {
   // expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout}}
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout>
   return
 }
 
 // -----
-func.func @create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
+func.func @create_nd_tdesc_6(%src: memref<128x128xf32>) {
   // expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout}}
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout>
   return
 }
 
 // -----
-func.func @create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
+func.func @create_nd_tdesc_7(%src: memref<128x128xf32>) {
   // expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout}}
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout>
   return
 }
 
+// -----
+func.func @create_nd_tdesc_8(%src: ui64) {
+  // expected-error@+1 {{'xegpu.create_nd_tdesc' op Expecting strides and shape to be present for integer source}}
+  %1 = xegpu.create_nd_tdesc %src : ui64-> !xegpu.tensor_desc<128x128xf32>
+  return
+}
+
+// -----
+func.func @create_nd_tdesc_9(%src: ui64) {
+  // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank}}
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : ui64-> !xegpu.tensor_desc<128x128xf32>
+  return
+}
+
+
 // -----
 func.func @prefetch_nd_vc_1(%src: memref<24x32xf16>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3bfe1fa81aa6e..695437354cd7c 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -17,8 +17,8 @@ gpu.func @create_nd_tdesc_1(%src: memref<24x32xf32>) {
 gpu.func @create_nd_tdesc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
   //CHECK: %[[C:.*]] = arith.constant 1 : index
   %c1 = arith.constant 1 : index
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]], [%[[arg2]], %[[arg1]]], [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
-  %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]], shape : [%[[arg2]], %[[arg1]]], strides : [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+  %1 = xegpu.create_nd_tdesc %src[%x, %y], shape:[%h, %w], strides: [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
   gpu.return
 }
 
@@ -62,6 +62,47 @@ gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
 }
 
 
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
+gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
+  //CHECK: %[[C:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+  gpu.return
+}
+
+// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
+gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+  %2 = xegpu.create_nd_tdesc %src, shape : [%h, %w], strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+  gpu.return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
+
+gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[%arg3, %arg4], shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src[%x, %y], shape:[%h, %w], strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+  gpu.return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
+gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %2 = xegpu.create_nd_tdesc %src, shape:[%h, %w], strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+  gpu.return
+}
+
 // CHECK: gpu.func @prefetch_nd(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @prefetch_nd(%src: memref<24x32xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3d91b2269bc4b..0bfbc4a35c03b 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -150,16 +150,16 @@ gpu.module @test {
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: ui64, %[[ARG2:[0-9a-zA-Z]+]]: index,
 // CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: index, %[[ARG4:[0-9a-zA-Z]+]]: index,
 // CHECK-SAME: %[[ARG5:[0-9a-zA-Z]+]]: index, %[[ARG6:[0-9a-zA-Z]+]]: index, %[[ARG7:[0-9a-zA-Z]+]]: index) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}], shape : [%[[ARG2]], %[[ARG3]]], strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
 // CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], shape : [%[[ARG2]], %[[ARG3]]], strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
 // CHECK: xegpu.store_nd %[[T1]], %[[T2]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
 gpu.module @test {
   gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
     %c0 = arith.constant 0 : index
-    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0], shape:[%arg2, %arg3], strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %1 = xegpu.load_nd %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
-    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], shape:[%arg2, %arg3], strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
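For illustration, a minimal sketch of the two assembly forms this change enables, mirroring the test cases above (the value names %static_src, %ptr, %h, %w and %c1 are placeholders, not part of the patch):

  // Offsets may now be omitted when the source is a statically shaped memref.
  %td0 = xegpu.create_nd_tdesc %static_src : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

  // For a ui64 (pointer-like) source, shape and strides remain mandatory and are
  // now spelled with explicit `shape :` / `strides :` keywords.
  %td1 = xegpu.create_nd_tdesc %ptr, shape : [%h, %w], strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>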