diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 3dbc8e9916df6..7e77067936743 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -877,6 +878,24 @@ struct ConvertArmSMEToLLVMPass
 
     if (failed(applyPartialConversion(function, target, std::move(patterns))))
       signalPassFailure();
+
+    // Walk the function and fail if there are unexpected operations on SME
+    // tile types after conversion.
+    function->walk([&](Operation *op) {
+      // These ops are legal post conversion, skip these.
+      if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
+          !op->isRegistered())
+        return;
+      auto isSMETileType = [](Type type) {
+        return arm_sme::isValidSMETileVectorType(type);
+      };
+      if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
+          llvm::any_of(op->getOperandTypes(), isSMETileType)) {
+        op->emitOpError("unexpected operation with SME tile type after "
+                        "conversion to LLVM");
+        signalPassFailure();
+      }
+    });
   }
 };
 
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 14b1f323da3a2..ef85f3d069d74 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -629,18 +629,18 @@ func.func @arm_sme_streaming_vl_double_words() -> index {
 
 // CHECK-LABEL: arm_sme_fmopa_2way_f16f16_to_f32
 // CHECK: "arm_sme.intr.mopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
-func.func @arm_sme_fmopa_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_fmopa_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) {
   %result = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
-  return %result : vector<[4]x[4]xf32>
+  "test.some_use"(%result) : (vector<[4]x[4]xf32>) -> ()
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_fmopa_2way_bf16bf16_to_f32
 // CHECK: "arm_sme.intr.mopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
-func.func @arm_sme_fmopa_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_fmopa_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) {
   %result = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
-  return %result : vector<[4]x[4]xf32>
+  "test.some_use"(%result) : (vector<[4]x[4]xf32>) -> ()
 }
 
 //===----------------------------------------------------------------------===//
@@ -651,18 +651,18 @@ func.func @arm_sme_fmopa_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: ve
 
 // CHECK-LABEL: arm_sme_fmops_2way_f16f16_to_f32
 // CHECK: "arm_sme.intr.mops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
-func.func @arm_sme_fmops_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_fmops_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) {
   %result = arm_sme.fmops_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
-  return %result : vector<[4]x[4]xf32>
+  "test.some_use"(%result) : (vector<[4]x[4]xf32>) -> ()
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_fmops_2way_bf16bf16_to_f32
 // CHECK: "arm_sme.intr.mops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
-func.func @arm_sme_fmops_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_fmops_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) {
   %result = arm_sme.fmops_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
-  return %result : vector<[4]x[4]xf32>
+  "test.some_use"(%result) : (vector<[4]x[4]xf32>) -> ()
 }
 
 //===----------------------------------------------------------------------===//
@@ -673,9 +673,9 @@ func.func @arm_sme_fmops_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: ve
 
 // CHECK-LABEL: arm_sme_smopa_2way_i16i16_to_i32
 // CHECK: "arm_sme.intr.smopa.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_smopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_smopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.smopa_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
 }
 
 //===----------------------------------------------------------------------===//
@@ -686,9 +686,9 @@ func.func @arm_sme_smopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_smops_2way_i16i16_to_i32
 // CHECK: "arm_sme.intr.smops.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_smops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_smops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.smops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
 }
 
 //===----------------------------------------------------------------------===//
@@ -699,9 +699,10 @@ func.func @arm_sme_smops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_umopa_2way_i16i16_to_i32
 // CHECK: "arm_sme.intr.umopa.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_umopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_umopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.umopa_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -712,9 +713,10 @@ func.func @arm_sme_umopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_umops_2way_i16i16_to_i32
 // CHECK: "arm_sme.intr.umops.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.umops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -725,18 +727,20 @@ func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_smopa_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.smopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_smopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_smopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
   %result = arm_sme.smopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_smopa_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.smopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_smopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_smopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.smopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %result : vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -747,18 +751,20 @@ func.func @arm_sme_smopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_smops_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.smops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_smops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_smops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
   %result = arm_sme.smops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_smops_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.smops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_smops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_smops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.smops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %result : vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -769,18 +775,20 @@ func.func @arm_sme_smops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_umopa_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.umopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_umopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_umopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
   %result = arm_sme.umopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_umopa_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.umopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_umopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_umopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.umopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %result : vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -791,18 +799,20 @@ func.func @arm_sme_umopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_umops_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.umops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_umops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_umops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
   %result = arm_sme.umops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_umops_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.umops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_umops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_umops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.umops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %result : vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -813,18 +823,20 @@ func.func @arm_sme_umops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_sumopa_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.sumopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_sumopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_sumopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
   %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_sumopa_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.sumopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_sumopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_sumopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %result : vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -835,18 +847,20 @@ func.func @arm_sme_sumopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vect
 
 // CHECK-LABEL: arm_sme_sumops_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.sumops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_sumops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_sumops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
   %result = arm_sme.sumops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_sumops_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.sumops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_sumops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_sumops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.sumops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %result : vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -857,18 +871,20 @@ func.func @arm_sme_sumops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vect
 
 // CHECK-LABEL: arm_sme_usmopa_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.usmopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_usmopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
-  %reuslt = arm_sme.usmopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %reuslt : vector<[4]x[4]xi32>
+func.func @arm_sme_usmopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
+  %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_usmopa_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.usmopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_usmopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
-  %reuslt = arm_sme.usmopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %reuslt : vector<[2]x[2]xi64>
+func.func @arm_sme_usmopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
+  %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -879,16 +895,45 @@ func.func @arm_sme_usmopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vect
 
 // CHECK-LABEL: arm_sme_usmops_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.usmops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_usmops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
-  %reuslt = arm_sme.usmops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %reuslt : vector<[4]x[4]xi32>
+func.func @arm_sme_usmops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
+  %result = arm_sme.usmops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_usmops_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.usmops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
-  %reuslt = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %reuslt : vector<[2]x[2]xi64>
+func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
+  %result = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// Operations on SME tile types allowed after conversion
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// The following operations on SME tile types are permitted after conversion:
+//
+// - arm_sme.copy_tile
+// - arm_sme.get_tile
+// - cf.br
+// - any unregistered op such as 'test.some_use'.
+//
+// This test verifies this. Conversion will fail for operations with SME tile
+// types not in this list; this is tested in 'unsupported.mlir'.
+
+func.func @ops_on_tiles_legal_post_conversion(%ub : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+  cf.br ^bb1(%copy : vector<[4]x[4]xf32>)
+^bb1(%x : vector<[4]x[4]xf32>):
+  "test.some_use"(%x) : (vector<[4]x[4]xf32>) -> ()
+  return
 }
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 2c3868d7f25cb..91c1b92b01224 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -141,7 +141,7 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m
   // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   "test.some_use"(%useAllTiles) : (vector<[16]x[16]xi8>) -> ()
-  return %loadSlice : vector<[4]x[4]xf32>
+  "test.some_use"(%loadSlice) : (vector<[4]x[4]xf32>) -> ()
 }
 // AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
 // AFTER-TILE-ALLOC: arm_sme.load_tile_slice
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index a62ca080ab8d9..b2c41f284fb86 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" -verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" -verify-diagnostics -split-input-file
 
 //===----------------------------------------------------------------------===//
 // arm_sme.outerproduct
@@ -6,9 +6,21 @@
 
 func.func @arm_sme_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>) {
   %acc = arm_sme.get_tile : vector<[16]x[16]xi8>
+  // expected-error@below {{unexpected operation with SME tile type after conversion to LLVM}}
   // expected-error@+2 {{failed to legalize operation 'arm_sme.outerproduct'}}
   // expected-error@+1 {{unsupported type}}
   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc) : vector<[16]xi8>, vector<[16]xi8>
   "test.some_use"(%0) : (vector<[16]x[16]xi8>) -> ()
 }
 
+//===----------------------------------------------------------------------===//
+// Unsupported operations on SME tile types
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @unsupported_arith_op(%a : vector<[4]x[4]xf32>, %b : vector<[4]x[4]xf32>) {
+  // expected-error@below {{unexpected operation with SME tile type after conversion to LLVM}}
+  %0 = arith.addf %a, %b : vector<[4]x[4]xf32>
+  "test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
+}