Skip to content

Commit 014e381

Browse files
committed
feat: support aten::arange converter
Signed-off-by: inocsin <[email protected]>
1 parent 5b6bd4c commit 014e381

File tree

2 files changed

+160
-1
lines changed

2 files changed

+160
-1
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,63 @@ auto aten_registrations TRTORCH_UNUSED =
467467
LOG_WARNING("Warning from TorchScript: " << *warning);
468468
return {};
469469
},
470-
EvalOptions()});
470+
EvalOptions()})
471+
.evaluator({c10::Symbol::fromQualString("aten::arange"),
            [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
              // aten::arange has three overloads that differ only in how many
              // leading Scalar arguments they carry:
              //   arange(end), arange.start(start, end), arange.start_step(start, end, step).
              // The trailing dtype/layout/device/pin_memory inputs are None (not
              // Scalars) here, so counting the Scalar inputs picks the overload.
              auto input_size = n->inputs().size();
              size_t scalar_count = 0;
              for (size_t i = 0; i < input_size; i++) {
                if (args.at(n->input(i)).IValue()->isScalar()) {
                  scalar_count += 1;
                }
              }
              // If any Scalar argument is a double the whole range is evaluated in
              // floating point (matching TorchScript promotion); otherwise it stays
              // integral so torch::arange infers an integer result.
              // NOTE: unwrap to double, not float — TorchScript Scalars hold doubles,
              // and narrowing can shift the element count at range boundaries.
              bool is_float = false;
              for (size_t i = 0; i < scalar_count; i++) {
                if (args.at(n->input(i)).IValue()->isDouble()) {
                  is_float = true;
                }
              }
              if (scalar_count == 1) { // arange(end)
                if (is_float) {
                  // arange(end) and arange(ceil(end)) produce the same 0-based
                  // unit-step sequence, so `end` is passed through unmodified.
                  return torch::arange(args.at(n->input(0)).unwrapToScalar().to<double>());
                }
                return torch::arange(args.at(n->input(0)).unwrapToInt());
              } else if (scalar_count == 2) { // arange(start, end)
                if (is_float) {
                  auto start_scalar = args.at(n->input(0)).unwrapToScalar().to<double>();
                  auto end_scalar = args.at(n->input(1)).unwrapToScalar().to<double>();
                  return torch::arange(start_scalar, end_scalar);
                }
                return torch::arange(args.at(n->input(0)).unwrapToInt(), args.at(n->input(1)).unwrapToInt());
              } else if (scalar_count == 3) { // arange(start, end, step)
                if (is_float) {
                  auto start_scalar = args.at(n->input(0)).unwrapToScalar().to<double>();
                  auto end_scalar = args.at(n->input(1)).unwrapToScalar().to<double>();
                  auto step_scalar = args.at(n->input(2)).unwrapToScalar().to<double>();
                  return torch::arange(start_scalar, end_scalar, step_scalar);
                }
                return torch::arange(
                    args.at(n->input(0)).unwrapToInt(),
                    args.at(n->input(1)).unwrapToInt(),
                    args.at(n->input(2)).unwrapToInt());
              } else {
                // Dispatch above is driven by scalar_count, so report it (along
                // with the raw input count) instead of input_size alone.
                TRTORCH_THROW_ERROR(
                    "Invalid arguments for aten::arange: found " << scalar_count << " Scalar inputs out of "
                                                                 << input_size << " total inputs");
              }
              return {};
            },
            EvalOptions().validSchemas({
                R"SIG(aten::arange(Scalar end, *, int? dtype=None, int? layout=None,
                Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
                R"SIG(aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None,
                Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
                R"SIG(aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None,
                Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
            })});
471527
} // namespace
472528
} // namespace evaluators
473529
} // namespace conversion

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,107 @@ TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
7575
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
7676

7777
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
78+
}
79+
80+
TEST(Evaluators, ATenArangeIntEvaluatesCorrectly) {
  // arange(end) with an integer end: evaluator output must match TorchScript.
  const auto source = R"IR(
      graph():
        %0 : int = prim::Constant[value=51]()
        %1 : None = prim::Constant()
        %2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
        return (%2))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, &*parsed_g);

  // Run the graph once through the JIT interpreter and once through the
  // TRTorch evaluator, then compare the resulting tensors.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_g, {});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_g->block(), {});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0].toTensor(), trt_out[0].toTensor(), 2e-6));
}
96+
97+
TEST(Evaluators, ATenArangeFloatEvaluatesCorrectly) {
  // arange(end) with a floating-point end: evaluator must agree with TorchScript.
  const auto source = R"IR(
      graph():
        %0 : float = prim::Constant[value=51.2]()
        %1 : None = prim::Constant()
        %2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
        return (%2))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, &*parsed_g);

  // Evaluate via the JIT interpreter and via the TRTorch evaluator, then compare.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_g, {});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_g->block(), {});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0].toTensor(), trt_out[0].toTensor(), 2e-6));
}
112+
113+
TEST(Evaluators, ATenArangeStartEndIntEvaluatesCorrectly) {
  // arange(start, end) with integer bounds.
  const auto source = R"IR(
      graph():
        %0 : int = prim::Constant[value=1]()
        %1 : int = prim::Constant[value=51]()
        %2 : None = prim::Constant()
        %3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
        return (%3))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, &*parsed_g);

  // Evaluate via the JIT interpreter and via the TRTorch evaluator, then compare.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_g, {});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_g->block(), {});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0].toTensor(), trt_out[0].toTensor(), 2e-6));
}
129+
130+
TEST(Evaluators, ATenArangeStartEndFloatEvaluatesCorrectly) {
  // arange(start, end) with floating-point bounds.
  const auto source = R"IR(
      graph():
        %0 : float = prim::Constant[value=1.5]()
        %1 : float = prim::Constant[value=51.2]()
        %2 : None = prim::Constant()
        %3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
        return (%3))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, &*parsed_g);

  // Evaluate via the JIT interpreter and via the TRTorch evaluator, then compare.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_g, {});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_g->block(), {});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0].toTensor(), trt_out[0].toTensor(), 2e-6));
}
146+
147+
TEST(Evaluators, ATenArangeStartEndStepIntEvaluatesCorrectly) {
  // arange(start, end, step) with all-integer arguments.
  const auto source = R"IR(
      graph():
        %0 : int = prim::Constant[value=1]()
        %1 : int = prim::Constant[value=51]()
        %2 : int = prim::Constant[value=1]()
        %3 : None = prim::Constant()
        %4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
        return (%4))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, &*parsed_g);

  // Evaluate via the JIT interpreter and via the TRTorch evaluator, then compare.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_g, {});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_g->block(), {});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0].toTensor(), trt_out[0].toTensor(), 2e-6));
}
164+
165+
TEST(Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) {
  // arange(start, end, step) with all-floating-point arguments.
  const auto source = R"IR(
      graph():
        %0 : float = prim::Constant[value=1.2]()
        %1 : float = prim::Constant[value=51.6]()
        %2 : float = prim::Constant[value=1.5]()
        %3 : None = prim::Constant()
        %4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
        return (%4))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, &*parsed_g);

  // Evaluate via the JIT interpreter and via the TRTorch evaluator, then compare.
  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed_g, {});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed_g->block(), {});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0].toTensor(), trt_out[0].toTensor(), 2e-6));
}

0 commit comments

Comments
 (0)