Skip to content

Commit a5bc3b0

Browse files
committed
refactor(//tests): Adding more specific tests and restructuring module
fallback tests Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 532efed commit a5bc3b0

File tree

7 files changed

+241
-146
lines changed

7 files changed

+241
-146
lines changed

tests/core/lowering/BUILD

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ lowering_test(
1212
)
1313

1414
cc_test(
15-
name = "test_module_level_fallback",
16-
srcs = ["test_module_level_fallback.cpp"],
15+
name = "test_module_fallback_passes",
16+
srcs = ["test_module_fallback_passes.cpp"],
1717
deps = [
1818
"//tests/util",
1919
"//core",
@@ -63,7 +63,7 @@ test_suite(
6363
name = "lowering_tests",
6464
tests = [
6565
":test_linear_to_addmm",
66-
":test_module_level_fallback",
66+
":test_module_fallback_passes",
6767
":test_operator_aliasing_pass",
6868
":test_remove_contiguous_pass",
6969
":test_remove_detach_pass",
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#include <string>
2+
#include <unordered_set>
3+
#include "core/compiler.h"
4+
#include "core/lowering/lowering.h"
5+
#include "gtest/gtest.h"
6+
#include "tests/util/util.h"
7+
#include "torch/script.h"
8+
#include "core/lowering/passes/passes.h"
9+
#include "torch/csrc/jit/passes/freeze_module.h"
10+
11+
TEST(Lowering, NotateModuleForFallbackWorksCorrectly) {
12+
torch::jit::script::Module mod;
13+
try {
14+
mod = torch::jit::load("tests/modules/module_fallback_scripted.jit.pt");
15+
} catch (const c10::Error& e) {
16+
std::cerr << "error loading the model\n";
17+
ASSERT_TRUE(false);
18+
}
19+
20+
std::unordered_set<std::string> mods_to_mark;
21+
mods_to_mark.insert("ModuleFallbackSub");
22+
23+
trtorch::core::lowering::passes::NotateModuleForFallback(mod, "", "forward", mods_to_mark);
24+
25+
auto g = mod.get_method("forward").graph();
26+
auto nodes = g->block()->nodes();
27+
28+
bool seen_enter = false;
29+
int64_t enter_count = 0;
30+
int64_t exit_count = 0;
31+
int64_t intermediate_nodes = 0;
32+
for (auto it = nodes.begin(); it != nodes.end(); it++) {
33+
auto n = *it;
34+
if (n->kind() == torch::jit::prim::Enter) {
35+
enter_count++;
36+
auto internal_n = *(++it);
37+
ASSERT_TRUE(internal_n->kind() != torch::jit::prim::Exit);
38+
intermediate_nodes++;
39+
auto end = *(++it);
40+
ASSERT_TRUE(end->kind() == torch::jit::prim::Exit);
41+
exit_count++;
42+
seen_enter = true;
43+
}
44+
}
45+
ASSERT_TRUE(seen_enter);
46+
ASSERT_TRUE(enter_count == 1);
47+
ASSERT_TRUE(intermediate_nodes == 1);
48+
ASSERT_TRUE(exit_count == 1);
49+
}
50+
51+
TEST(Lowering, MarkNodesForFallbackWorksCorrectly) {
52+
torch::jit::script::Module mod;
53+
try {
54+
mod = torch::jit::load("tests/modules/module_fallback_scripted.jit.pt");
55+
} catch (const c10::Error& e) {
56+
std::cerr << "error loading the model\n";
57+
ASSERT_TRUE(false);
58+
}
59+
60+
std::unordered_set<std::string> mods_to_mark;
61+
mods_to_mark.insert("ModuleFallbackSub");
62+
63+
trtorch::core::lowering::passes::NotateModuleForFallback(mod, "", "forward", mods_to_mark);
64+
auto mod_ = torch::jit::freeze_module(mod);
65+
auto g = mod_.get_method("forward").graph();
66+
trtorch::core::lowering::passes::MarkNodesForFallback(g, true);
67+
auto nodes = g->block()->nodes();
68+
69+
int64_t num_marked_nodes = 0;
70+
71+
for (auto n : nodes) {
72+
auto has_compile_attribute = n->hasAttribute(c10::Symbol::attr("to_compile"));
73+
if (has_compile_attribute && n->i(c10::Symbol::attr("to_compile")) == (int64_t) false) {
74+
num_marked_nodes++;
75+
}
76+
}
77+
78+
ASSERT_TRUE(num_marked_nodes == 2);
79+
}
80+
81+
TEST(Lowering, LowerAndPartitionSimpleModuleFallbackCorrectly) {
82+
torch::jit::script::Module mod;
83+
try {
84+
mod = torch::jit::load("tests/modules/module_fallback_scripted.jit.pt");
85+
} catch (const c10::Error& e) {
86+
std::cerr << "error loading the model\n";
87+
ASSERT_TRUE(false);
88+
}
89+
90+
const std::vector<std::vector<int64_t>> input_shapes = {{1, 1, 16, 16}};
91+
std::vector<torch::jit::IValue> jit_inputs_ivalues;
92+
std::vector<torch::jit::IValue> trt_inputs_ivalues;
93+
for (auto in_shape : input_shapes) {
94+
auto in = at::randint(5, in_shape, {at::kCUDA});
95+
jit_inputs_ivalues.push_back(in.clone());
96+
trt_inputs_ivalues.push_back(in.clone());
97+
}
98+
99+
std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 1, 16, 16})};
100+
trtorch::core::CompileSpec cfg(input_ranges);
101+
cfg.partition_info.enabled = true;
102+
cfg.lower_info.forced_fallback_modules.push_back("ModuleFallbackSub");
103+
104+
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
105+
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
106+
107+
auto g = trt_mod.get_method("forward").graph();
108+
auto nodes = g->block()->nodes();
109+
std::size_t curr_node = 0;
110+
for (const auto n : nodes) {
111+
if (curr_node == 5) {
112+
ASSERT_TRUE(n->kind() == torch::jit::aten::conv2d);
113+
ASSERT_TRUE(n->i(c10::Symbol::attr("to_compile")) == (int64_t) false);
114+
} else if (curr_node == 6) {
115+
ASSERT_TRUE(n->kind() == torch::jit::aten::relu);
116+
ASSERT_TRUE(n->i(c10::Symbol::attr("to_compile")) == (int64_t) false);
117+
} else if (curr_node == 7) {
118+
ASSERT_TRUE(n->kind() == torch::jit::prim::GetAttr);
119+
ASSERT_TRUE(n->s(c10::Symbol::attr("name")).find("trt_engine") != std::string::npos);
120+
}
121+
curr_node++;
122+
}
123+
124+
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
125+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
126+
}
127+

tests/core/lowering/test_module_level_fallback.cpp

Lines changed: 0 additions & 142 deletions
This file was deleted.

tests/cpp/BUILD

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ test_suite(
1616
":test_modules_as_engines",
1717
":test_multiple_registered_engines",
1818
":test_serialization",
19+
":test_module_fallback"
1920
],
2021
)
2122

@@ -27,6 +28,7 @@ test_suite(
2728
":test_modules_as_engines",
2829
":test_multiple_registered_engines",
2930
":test_serialization",
31+
":test_module_fallback"
3032
],
3133
)
3234

@@ -79,6 +81,22 @@ cc_test(
7981
],
8082
)
8183

84+
cc_test(
85+
name = "test_module_fallback",
86+
srcs = ["test_module_fallback.cpp"],
87+
data = [
88+
"//tests/modules:jit_models",
89+
],
90+
deps = [
91+
"//cpp/api:trtorch",
92+
"//tests/util",
93+
"@googletest//:gtest_main",
94+
] + select({
95+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
96+
"//conditions:default": ["@libtorch//:libtorch"],
97+
})
98+
)
99+
82100
cc_test(
83101
name = "test_compiled_modules",
84102
srcs = ["test_compiled_modules.cpp"],

tests/cpp/cpp_api_test.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class CppAPITests : public testing::TestWithParam<PathAndInSize> {
1818
mod = torch::jit::load(path);
1919
} catch (const c10::Error& e) {
2020
std::cerr << "error loading the model\n";
21-
return;
21+
ASSERT_TRUE(false);
2222
}
2323
input_shapes = std::get<1>(params);
2424
threshold = std::get<2>(params);

0 commit comments

Comments
 (0)