diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index fbf3f19cde0e9..777d7cfd558d2 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -54,6 +54,14 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> { return $_op.getOperation()->getOperand(1); }]>, InterfaceMethod< + /*desc=*/"Returns the result value.", + /*retTy=*/"OpResult", + /*methodName=*/"res", + /*args=*/(ins), + /*methodBody=*/[{ + return $_op.getOperation()->getResult(0); + }]>, + InterfaceMethod< /*desc=*/[{ Returns whether the given op has indexing maps that correspond to a row-major matmul operation. diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 2dec4ba3c001e..76b698d1d1a7b 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -8,6 +8,7 @@ target_link_libraries(MLIRDialectTests add_subdirectory(ArmSME) add_subdirectory(Index) +add_subdirectory(Linalg) add_subdirectory(LLVMIR) add_subdirectory(MemRef) add_subdirectory(SCF) diff --git a/mlir/unittests/Dialect/Linalg/CMakeLists.txt b/mlir/unittests/Dialect/Linalg/CMakeLists.txt new file mode 100644 index 0000000000000..080caab8d075e --- /dev/null +++ b/mlir/unittests/Dialect/Linalg/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_unittest(MLIRLinalgTests + LinalgInterfacesTest.cpp +) +target_link_libraries(MLIRLinalgTests + PRIVATE + MLIRLinalgDialect + ) + diff --git a/mlir/unittests/Dialect/Linalg/LinalgInterfacesTest.cpp b/mlir/unittests/Dialect/Linalg/LinalgInterfacesTest.cpp new file mode 100644 index 0000000000000..8cc4a5e37c452 --- /dev/null +++ b/mlir/unittests/Dialect/Linalg/LinalgInterfacesTest.cpp @@ -0,0 +1,43 @@ +//===- LinalgInterfacesTest.cpp - LinalgInterfaces unit tests ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "gtest/gtest.h" + +using namespace mlir; + +class LinalgInterfacesTest : public ::testing::Test { +protected: + LinalgInterfacesTest() { + context.getOrLoadDialect(); + } + + mlir::MLIRContext context; +}; + +TEST_F(LinalgInterfacesTest, ContractionOpOperandResultAccessor) { + OpBuilder b(&context); + SmallVector lhsShape = {1, 2}; + SmallVector rhsShape = {2, 4}; + SmallVector resShape = {1, 4}; + auto lhs = b.create(UnknownLoc::get(&context), lhsShape, + b.getF32Type()); + auto rhs = b.create(UnknownLoc::get(&context), rhsShape, + b.getF32Type()); + auto out = b.create(UnknownLoc::get(&context), resShape, + b.getF32Type()); + Operation *op = b.create( + UnknownLoc::get(&context), ValueRange{lhs, rhs}, ValueRange{out}); + auto contractOp = llvm::cast(op); + + EXPECT_EQ(contractOp.lhs(), op->getOperand(0)); + EXPECT_EQ(contractOp.rhs(), op->getOperand(1)); + EXPECT_EQ(contractOp.res(), op->getResult(0)); +}