Skip to content

Commit 6aaba3b

Browse files
committed
feat(aten::sqrt): Adding support for sqrt evaluators
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 50f012e commit 6aaba3b

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,25 @@ auto aten_registrations TRTORCH_UNUSED =
540540
"aten::floor.int(int a) -> (int)",
541541
"aten::floor.float(float a) -> (int)",
542542
})})
543+
.evaluator({c10::Symbol::fromQualString("aten::sqrt"),
544+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
545+
if (args.at(n->input(0)).IValue()->isInt()) {
546+
auto a = args.at(n->input(0)).unwrapToInt();
547+
return std::sqrt(static_cast<double>(a));
548+
} else if (args.at(n->input(0)).IValue()->isDouble()) {
549+
auto a = args.at(n->input(0)).unwrapToDouble();
550+
return std::sqrt(a);
551+
} else {
552+
TRTORCH_THROW_ERROR(
553+
"Unimplemented data type for aten::sqrt evaluator: "
554+
<< args.at(n->input(0)).IValue()->type()->str());
555+
return {};
556+
}
557+
},
558+
EvalOptions().validSchemas({
559+
"aten::sqrt.int(int a) -> (float)",
560+
"aten::sqrt.float(float a) -> (float)",
561+
})})
543562
.evaluator({c10::Symbol::fromQualString("aten::warn"),
544563
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
545564
auto warning = args.at(n->input(0)).IValue();

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,37 @@ TEST(Evaluators, ATenAppendWithITensorAndTensorEvaluatesCorrectly) {
357357
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
358358
}
359359

360+
TEST(Evaluators, SqrtIntEvaluatesCorrectly) {
361+
const auto graph = R"IR(
362+
graph():
363+
%1 : int = prim::Constant[value=9]()
364+
%2 : float = aten::sqrt(%1)
365+
return (%2))IR";
366+
367+
auto g = std::make_shared<torch::jit::Graph>();
368+
torch::jit::parseIR(graph, g.get());
369+
370+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
371+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
372+
373+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
374+
}
375+
376+
TEST(Evaluators, SqrtFloatEvaluatesCorrectly) {
377+
const auto graph = R"IR(
378+
graph():
379+
%1 : float = prim::Constant[value=9.0]()
380+
%2 : float = aten::sqrt(%1)
381+
return (%2))IR";
382+
383+
auto g = std::make_shared<torch::jit::Graph>();
384+
torch::jit::parseIR(graph, g.get());
385+
386+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
387+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
388+
389+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
390+
}
360391
TEST(Evaluators, ATenCloneEvaluatesCorrectly) {
361392
const auto graph = R"IR(
362393
graph(%0 : Tensor):

0 commit comments

Comments
 (0)