
Commit 35b2a73

[Torch] Add support for aten.any.dims
* Added lowering to Linalg-on-Tensors
* Added test to projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
1 parent c675b2f commit 35b2a73
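
For quick orientation, a minimal sketch (not part of the commit) of the behaviour being wired up: aten::any.dims reduces with a logical OR over a list of dimensions, unlike aten::any.dim, which takes a single int. The snippet assumes a PyTorch build that exposes the aten::any.dims overload, as the e2e test below does.

    import torch

    x = torch.rand(3, 4, 5) > 0.5            # bool tensor, shape (3, 4, 5)
    out = torch.ops.aten.any(x, dim=[0, 1])  # resolves to any.dims; shape (5,)

    # Same result via two single-dim reductions.
    assert torch.equal(out, x.any(dim=0).any(dim=0))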

File tree

8 files changed: +99 -2 lines changed


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 26 additions & 0 deletions
@@ -10801,6 +10801,32 @@ def Torch_AtenAnyDimOp : Torch_Op<"aten.any.dim", [
   }];
 }
 
+def Torch_AtenAnyDimsOp : Torch_Op<"aten.any.dims", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::any.dims : (Tensor, int[]?, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalListOfTorchIntType:$dim,
+    Torch_BoolType:$keepdim
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAnyDimsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenAnyDimsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenArangeOp : Torch_Op<"aten.arange", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Conversion/TorchToLinalg/Reduction.cpp

Lines changed: 6 additions & 2 deletions
@@ -337,7 +337,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
     return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
   }
 
-  if (isa<AtenAnyOp>(op)) {
+  if (isa<AtenAnyOp, AtenAnyDimsOp>(op)) {
     return b.create<arith::ConstantOp>(loc, b.getBoolAttr(false));
   }
 
@@ -434,7 +434,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
     Value result = payloadArgs[1];
     Value self = convertScalarToDtype(b, loc, elem, resultElementType);
     return b.create<arith::AndIOp>(loc, self, result);
-  } else if (isa<AtenAnyOp>(op)) {
+  } else if (isa<AtenAnyOp, AtenAnyDimsOp>(op)) {
     Value elem = payloadArgs[0];
     Value result = payloadArgs[1];
     Value self = convertScalarToDtype(b, loc, elem, resultElementType);
@@ -532,6 +532,9 @@ class ConvertReductionOp : public ConversionPattern {
     if (auto allOp = dyn_cast<AtenAllDimOp>(op))
      return computeReductionOpInfoForDimVariantOp(allOp, operands, rewriter);
 
+    if (auto anyOp = dyn_cast<AtenAnyDimsOp>(op))
+      return computeReductionOpInfoForDimVariantOp(anyOp, operands, rewriter);
+
     return rewriter.notifyMatchFailure(op, "not a supported reduce op");
   }
 
@@ -709,6 +712,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
   patterns.add<ConvertAtenMinMaxDimOp<AtenMinDimOp>>(typeConverter, context);
   target.addIllegalOp<AtenSumOp>();
   target.addIllegalOp<AtenAnyOp>();
+  target.addIllegalOp<AtenAnyDimsOp>();
   target.addIllegalOp<AtenAllOp>();
  target.addIllegalOp<AtenSumDimIntListOp>();
  target.addIllegalOp<AtenProdOp>();
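
To make the lowering above easier to follow, here is a rough Python model (an illustration, not code from this commit) of what the generated linalg.generic computes for aten.any.dims: the accumulator starts at the init element false and every input element is OR-ed in, mirroring createInitElementForReduceOp and createLinalgPayloadForReduceOp. any_dims_reference is a hypothetical helper name; non-negative dim indices are assumed for simplicity.

    import itertools
    import torch

    def any_dims_reference(x: torch.Tensor, dims: list, keepdim: bool = False) -> torch.Tensor:
        # Init element: the accumulator starts out all-false.
        kept_shape = [s for d, s in enumerate(x.shape) if d not in dims]
        out = torch.zeros(kept_shape, dtype=torch.bool)
        for index in itertools.product(*(range(s) for s in x.shape)):
            out_index = tuple(i for d, i in enumerate(index) if d not in dims)
            # Payload: OR the bool-converted element into the accumulator.
            out[out_index] = out[out_index] | bool(x[index] != 0)
        if keepdim:
            for d in sorted(dims):
                out = out.unsqueeze(d)
        return out

    x = torch.rand(3, 4, 5)
    assert torch.equal(any_dims_reference(x, [0, 1]), x.any(dim=0).any(dim=0))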

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 16 additions & 0 deletions
@@ -1120,6 +1120,22 @@ OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenAnyDimsOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenAnyDimsOp::fold(FoldAdaptor adaptor) {
+  auto resultType = dyn_cast<ValueTensorType>(getResult().getType());
+  auto resultShape = resultType.toBuiltinTensor().getShape();
+  auto inputType = dyn_cast<ValueTensorType>(getOperand(0).getType());
+  auto inputShape = inputType.toBuiltinTensor().getShape();
+  if ((inputType.getDtype() == resultType.getDtype()) &&
+      (inputShape == resultShape)) {
+    return getSelf();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // AtenLenTOp
 //===----------------------------------------------------------------------===//
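
The new folder only fires when the reduction is a no-op: the result value tensor has the same dtype and the same static shape as the input, so the op can simply be replaced by its operand. A tiny Python mirror of that condition, illustrative only and with made-up names:

    def any_dims_would_fold(input_shape, input_dtype, result_shape, result_dtype):
        # Mirrors AtenAnyDimsOp::fold: identical dtype and identical shape
        # mean nothing is actually reduced, so the op folds to its input.
        return input_dtype == result_dtype and list(input_shape) == list(result_shape)

    assert any_dims_would_fold([3, 4], "i1", [3, 4], "i1")    # folds to self
    assert not any_dims_would_fold([3, 4], "f32", [4], "i1")  # real reduction, no fold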

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 18 additions & 0 deletions
@@ -7496,6 +7496,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
 "    return %1 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.any.dims\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.derefine %none : !torch.none to !torch.any\n"
+"    %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
+"    return %1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.all.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
 "    %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
 "    %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
@@ -15368,6 +15374,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.any.dims\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
+"      torch.prim.If.yield %0#1 : !torch.int\n"
+"    } else {\n"
+"      torch.prim.If.yield %int11 : !torch.int\n"
+"    }\n"
+"    return %2 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.all.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
 "    %int11 = torch.constant.int 11\n"
 "    %int0 = torch.constant.int 0\n"
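
The generated shape function above delegates to sum_mean_dim, so the result shape drops the reduced dimensions, or keeps them as size 1 with keepdim. A small illustrative expectation, not taken from the commit:

    def expected_any_dims_shape(shape, dims, keepdim=False):
        # Reduce over `dims`; None means reduce over every dimension.
        if dims is None:
            dims = list(range(len(shape)))
        return [1 if d in dims else s
                for d, s in enumerate(shape)
                if keepdim or d not in dims]

    assert expected_any_dims_shape([3, 4, 5], [0, 1]) == [5]
    assert expected_any_dims_shape([3, 4, 5], [0, 1], keepdim=True) == [1, 1, 5]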

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions
@@ -823,6 +823,7 @@
     "RandnLikeDtypeModule_basic",
     "RandnLikeModule_basic",
     "RandnModule_basic",
+    "ReduceAnyDimsFloatModule_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
     "ReflectionPad1dModule3dInput_Left",
@@ -3778,6 +3779,7 @@
     "RandnLikeModule_basic",
     "RandnModule_basic",
     "ReduceAllDimEmpty_basic",
+    "ReduceAnyDimsFloatModule_basic",
     "ReduceFrobeniusNormComplexModule_basic",
     "ReduceL1NormComplexModule_basic",
     "ReduceL2NormComplexModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 10 additions & 0 deletions
@@ -772,6 +772,9 @@ def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]:
 def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.argmax(self, dim, keepdim)
 
+def aten〇any〇dims〡shape(self: List[int], dim: Optional[List[int]] = None, keepdim: bool = False) -> List[int]:
+    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
+
 def aten〇all〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.argmax(self, dim, keepdim)
 
@@ -5215,6 +5218,13 @@ def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim
         return self_dtype
     return torch.bool
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇any〇dims〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, keepdim: bool = False) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if self_dtype == torch.uint8:
+        return self_dtype
+    return torch.bool
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
 def aten〇all〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> int:
     self_rank, self_dtype = self_rank_dtype
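
The dtype rule above matches the existing aten.any.dim rule: a uint8 input keeps its dtype (PyTorch's legacy Byte behaviour), anything else reduces to bool. A trivial standalone mirror for reference, not part of the commit:

    import torch

    def expected_any_dims_dtype(self_dtype: torch.dtype) -> torch.dtype:
        # uint8 propagates; every other input dtype produces a bool result.
        return self_dtype if self_dtype == torch.uint8 else torch.bool

    assert expected_any_dims_dtype(torch.float32) is torch.bool
    assert expected_any_dims_dtype(torch.uint8) is torch.uint8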

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -836,6 +836,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)")
     emit("aten::any : (Tensor) -> (Tensor)")
     emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
+    emit("aten::any.dims : (Tensor, int[]?, bool) -> (Tensor)", has_folder=True)
     emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
     emit(
         "aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)"

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

Lines changed: 20 additions & 0 deletions
@@ -302,6 +302,26 @@ def ReduceAnyDimFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5))
 
 
+class ReduceAnyDimsFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.any(a, dim=[0, 1])
+
+
+@register_test_case(module_factory=lambda: ReduceAnyDimsFloatModule())
+def ReduceAnyDimsFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
 # ==============================================================================
 
 