From f5d1818634b782645692d71f2b18dbd2a8d27b04 Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Fri, 1 Sep 2023 09:43:39 +0000 Subject: [PATCH] [mlir][ArmSME] Use ArmSMETypeConverter for all VectorToLLVM patterns LLVMTypeConverter::convertVectorType asserts on n-D scalable vectors to prevent generating illegal LLVM IR, since LLVM doesn't support arrays of scalable vectors. The ArmSMETypeConverter disables this conversion, but is only used for ArmSME dialect conversions that rewrite higher-level custom ArmSME ops to intrinsics. This is problematic if we want to lower Vector ops directly to ArmSME intrinsics, as the assert fires for ops that have dialect conversion patterns (defined in ConvertVectorToLLVMPass, e.g. populateVectorToLLVMConversionPatterns) that use the LLVMTypeConverter. There are three options to get around this: 1. Avoid the generic VectorToLLVM dialect conversion patterns (and thus the assert) altogether by first lowering Vector ops to custom ArmSME ops. 2. Disable the generic VectorToLLVM dialect conversion patterns if ArmSME is enabled. 3. Disable n-D scalable vector type conversion for all dialect conversion patterns if SME is enabled. Option 1 is already done for several Vector ops such as vector.load and vector.store as part of ConvertVectorToArmSME, but where possible we'd like to avoid bloating the ArmSME dialect by having to mirror all the Vector ops. Option 2 is undesirable as the generic conversions should only be disabled for the 2-d scalable vector types the ArmSME patterns apply to. We'd still like Vector ops with other types to get lowered via the default path when ArmSME is enabled. This patch therefore implements option 3 to use the ArmSMETypeConverter for all VectorToLLVM conversion patterns when ArmSME is enabled. --- .../Conversion/LLVMCommon/TypeConverter.h | 7 +++--- .../mlir/Dialect/ArmSME/Transforms/Passes.h | 13 +++++++++++ .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 23 +++++++++++-------- .../ArmSME/Transforms/ArmSMETypeConverter.cpp | 15 ++++++------ .../VectorToLLVM/vector-to-llvm.mlir | 1 + 5 files changed, 40 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h index ed174699314e8..43db7987e650a 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -238,14 +238,15 @@ class LLVMTypeConverter : public TypeConverter { /// Convert a memref type to a bare pointer to the memref element type. Type convertMemRefToBarePtr(BaseMemRefType type) const; - /// Convert a 1D vector type into an LLVM vector type. - Type convertVectorType(VectorType type) const; - /// Options for customizing the llvm lowering. LowerToLLVMOptions options; /// Data layout analysis mapping scopes to layouts active in them. const DataLayoutAnalysis *dataLayoutAnalysis; + +protected: + /// Convert a 1D vector type into an LLVM vector type. + Type convertVectorType(VectorType type) const; }; /// Callback to convert function argument types. It converts a MemRef function diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h index ab5c179f2dd77..ad3c010816fa3 100644 --- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h @@ -43,6 +43,19 @@ std::unique_ptr createTileAllocationPass(); class ArmSMETypeConverter : public LLVMTypeConverter { public: ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options); + +protected: + /// Convert an n-D vector type to an LLVM vector type. + /// + /// Disables type conversion of legal 2-D scalable vector types such as + /// `vector<[16]x[16]xi8>` for ArmSME, since LLVM does not support arrays of + /// scalable vectors and the LLVM type converter asserts on such types to + /// prevent generation of illegal LLVM IR. When lowering to ArmSME these types + /// should be eliminated before lowering to LLVM. + /// + /// Types unrelated to ArmSME are converted by + /// `LLVMTypeConverter::convertVectorType`. + Type convertVectorType(VectorType type) const; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 04570a750822a..c534ef6e408b8 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -83,21 +83,26 @@ void LowerVectorToLLVMPass::runOnOperation() { // Convert to the LLVM IR dialect. LowerToLLVMOptions options(&getContext()); options.useOpaquePointers = useOpaquePointers; - LLVMTypeConverter converter(&getContext(), options); + + LLVMTypeConverter *converter; + if (armSME) + converter = new arm_sme::ArmSMETypeConverter(&getContext(), options); + else + converter = new LLVMTypeConverter(&getContext(), options); + RewritePatternSet patterns(&getContext()); populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices); populateVectorTransferLoweringPatterns(patterns); - populateVectorToLLVMMatrixConversionPatterns(converter, patterns); + populateVectorToLLVMMatrixConversionPatterns(*converter, patterns); populateVectorToLLVMConversionPatterns( - converter, patterns, reassociateFPReductions, force32BitVectorIndices); - populateVectorToLLVMMatrixConversionPatterns(converter, patterns); + *converter, patterns, reassociateFPReductions, force32BitVectorIndices); + populateVectorToLLVMMatrixConversionPatterns(*converter, patterns); // Architecture specific augmentations. LLVMConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); - arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options); if (armNeon) { // TODO: we may or may not want to include in-dialect lowering to @@ -107,19 +112,19 @@ void LowerVectorToLLVMPass::runOnOperation() { } if (armSVE) { configureArmSVELegalizeForExportTarget(target); - populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); + populateArmSVELegalizeForLLVMExportPatterns(*converter, patterns); } if (armSME) { configureArmSMELegalizeForExportTarget(target); - populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns); + populateArmSMELegalizeForLLVMExportPatterns(*converter, patterns); } if (amx) { configureAMXLegalizeForExportTarget(target); - populateAMXLegalizeForLLVMExportPatterns(converter, patterns); + populateAMXLegalizeForLLVMExportPatterns(*converter, patterns); } if (x86Vector) { configureX86VectorLegalizeForExportTarget(target); - populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns); + populateX86VectorLegalizeForLLVMExportPatterns(*converter, patterns); } if (failed( diff --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp index 1cefc220ecf10..65da2a7a75d29 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp @@ -7,16 +7,17 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/ArmSME/Transforms/Passes.h" +#include "mlir/Dialect/ArmSME/Utils/Utils.h" using namespace mlir; arm_sme::ArmSMETypeConverter::ArmSMETypeConverter( MLIRContext *ctx, const LowerToLLVMOptions &options) : LLVMTypeConverter(ctx, options) { - // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable - // vectors (common in the context of ArmSME), e.g. - // `vector<[16]x[16]xi8>`, - // entering the LLVM Type converter. LLVM does not support arrays of scalable - // vectors, but in the case of SME such types are effectively eliminated when - // emitting ArmSME LLVM IR intrinsics. - addConversion([&](VectorType type) { return type; }); + addConversion([&](VectorType type) { return convertVectorType(type); }); +} + +Type arm_sme::ArmSMETypeConverter::convertVectorType(VectorType type) const { + if (arm_sme::isValidSMETileVectorType(type)) + return type; + return LLVMTypeConverter::convertVectorType(type); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 514594240d22a..3f897fbf01b7b 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1 enable-arm-sme' -split-input-file | FileCheck %s func.func @bitcast_f32_to_i32_vector_0d(%input: vector) -> vector {