Skip to content

Commit 53edf12

Browse files
author
Jerry Wu
authored
[mlir] Add res() method to linalg::ContractionOpInterface (#76539)
In addition to `lhs()` and `rhs()` to return left and right operands, add `res()` to return the result value.
1 parent ddf0096 commit 53edf12

File tree

4 files changed

+60
-0
lines changed

4 files changed

+60
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

+8
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
5454
return $_op.getOperation()->getOperand(1);
5555
}]>,
5656
InterfaceMethod<
57+
/*desc=*/"Returns the result value.",
58+
/*retTy=*/"OpResult",
59+
/*methodName=*/"res",
60+
/*args=*/(ins),
61+
/*methodBody=*/[{
62+
return $_op.getOperation()->getResult(0);
63+
}]>,
64+
InterfaceMethod<
5765
/*desc=*/[{
5866
Returns whether the given op has indexing maps that correspond to a
5967
row-major matmul operation.

mlir/unittests/Dialect/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ target_link_libraries(MLIRDialectTests
88

99
add_subdirectory(ArmSME)
1010
add_subdirectory(Index)
11+
add_subdirectory(Linalg)
1112
add_subdirectory(LLVMIR)
1213
add_subdirectory(MemRef)
1314
add_subdirectory(SCF)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
add_mlir_unittest(MLIRLinalgTests
2+
LinalgInterfacesTest.cpp
3+
)
4+
target_link_libraries(MLIRLinalgTests
5+
PRIVATE
6+
MLIRLinalgDialect
7+
)
8+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//===- LinalgInterfacesTest.cpp - LinalgInterfaces unit tests ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
10+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
11+
12+
#include "gtest/gtest.h"
13+
14+
using namespace mlir;
15+
16+
class LinalgInterfacesTest : public ::testing::Test {
17+
protected:
18+
LinalgInterfacesTest() {
19+
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
20+
}
21+
22+
mlir::MLIRContext context;
23+
};
24+
25+
TEST_F(LinalgInterfacesTest, ContractionOpOperandResultAccessor) {
26+
OpBuilder b(&context);
27+
SmallVector<int64_t> lhsShape = {1, 2};
28+
SmallVector<int64_t> rhsShape = {2, 4};
29+
SmallVector<int64_t> resShape = {1, 4};
30+
auto lhs = b.create<tensor::EmptyOp>(UnknownLoc::get(&context), lhsShape,
31+
b.getF32Type());
32+
auto rhs = b.create<tensor::EmptyOp>(UnknownLoc::get(&context), rhsShape,
33+
b.getF32Type());
34+
auto out = b.create<tensor::EmptyOp>(UnknownLoc::get(&context), resShape,
35+
b.getF32Type());
36+
Operation *op = b.create<linalg::MatmulOp>(
37+
UnknownLoc::get(&context), ValueRange{lhs, rhs}, ValueRange{out});
38+
auto contractOp = llvm::cast<linalg::ContractionOpInterface>(op);
39+
40+
EXPECT_EQ(contractOp.lhs(), op->getOperand(0));
41+
EXPECT_EQ(contractOp.rhs(), op->getOperand(1));
42+
EXPECT_EQ(contractOp.res(), op->getResult(0));
43+
}

0 commit comments

Comments
 (0)