Skip to content

Conversation

dchauhan-arm
Copy link
Contributor

Currently, the MaxPool2D operation in the TOSA MLIR dialect does not accept half-precision Fp16 and Bf16 tensors, converse to what is stated in the TOSA Specification.

This patch fixes the verifier to accept the two datatypes for input/output tensors, and adds related LIT test cases in Tosa/ops.mlir

Currently, the MaxPool2D operation in the TOSA MLIR dialect does not
accept half-precision Fp16 and Bf16 tensors, converse to what is stated
in the [TOSA Specification](https://www.mlplatform.org/tosa/tosa_spec.html#_max_pool2d).

This patch fixes the verifier to accept the two datatypes for
input/output tensors, and adds related LIT test cases in Tosa/ops.mlir
@llvmbot
Copy link
Member

llvmbot commented Oct 17, 2023

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Dhruv Chauhan (dchauhan-arm)

Changes

Currently, the MaxPool2D operation in the TOSA MLIR dialect does not accept half-precision Fp16 and Bf16 tensors, converse to what is stated in the TOSA Specification.

This patch fixes the verifier to accept the two datatypes for input/output tensors, and adds related LIT test cases in Tosa/ops.mlir


Full diff: https://github.com/llvm/llvm-project/pull/69332.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+1-1)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+16-2)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 4214bb57563285c..ee8f52deadbd152 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -691,7 +691,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 
     // Determine what the initial value needs to be for the max pool op.
     TypedAttr initialAttr;
-    if (resultETy.isF32())
+    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
       initialAttr = rewriter.getFloatAttr(
           resultETy, APFloat::getLargest(
                          cast<FloatType>(resultETy).getFloatSemantics(), true));
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e62bea515d06baa..8ce8fb73f29a504 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -97,12 +97,26 @@ func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -
 }
 
 // -----
-// CHECK-LABEL: max_pool2d
-func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+// CHECK-LABEL: max_pool2d_f32
+func.func @test_max_pool2d_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
   %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
   return %0 : tensor<1x32x32x8xf32>
 }
 
+// -----
+// CHECK-LABEL: max_pool2d_bf16
+func.func @test_max_pool2d_bf16(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16>
+  return %0 : tensor<1x32x32x8xbf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f16
+func.func @test_max_pool2d_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16>
+  return %0 : tensor<1x32x32x8xf16>
+}
+
 // -----
 // CHECK-LABEL: rfft2d
 func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {

Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@GeorgeARM GeorgeARM merged commit c926291 into llvm:main Oct 18, 2023
ljfitz pushed a commit to Xilinx/llvm-project that referenced this pull request Feb 22, 2024
Currently, the MaxPool2D operation in the TOSA MLIR dialect does not
accept half-precision Fp16 and Bf16 tensors, converse to what is stated
in the [TOSA
Specification](https://www.mlplatform.org/tosa/tosa_spec.html#_max_pool2d).

This patch fixes the verifier to accept the two datatypes for
input/output tensors, and adds related LIT test cases in Tosa/ops.mlir
ttjost added a commit to Xilinx/llvm-project that referenced this pull request Feb 22, 2024
[MLIR][TOSA] Fix f16/bf16 support for MaxPool2D (llvm#69332)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants