Skip to content

Commit 0050f0e

Browse files
committed
bug(//tests): Test to reproduce FP16 accuracy issue
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 2dd1ba3 commit 0050f0e

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#include "module_test.h"
2+
3+
TEST_P(ModuleTests, FP16CompiledModuleIsClose) {
4+
std::vector<torch::jit::IValue> jit_inputs_ivalues;
5+
std::vector<torch::jit::IValue> trt_inputs_ivalues;
6+
for (auto in_shape : input_shapes) {
7+
auto in = at::randint(5, in_shape, {at::kCUDA});
8+
in = in.to(torch::kF16);
9+
jit_inputs_ivalues.push_back(in.clone());
10+
trt_inputs_ivalues.push_back(in.clone());
11+
}
12+
13+
auto extra_info = trtorch::ExtraInfo({input_shapes});
14+
extra_info.op_precision = torch::kF16;
15+
extra_info.strict_types = true;
16+
17+
auto trt_mod = trtorch::CompileGraph(mod, extra_info);
18+
torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
19+
std::vector<at::Tensor> trt_results;
20+
trt_results.push_back(trt_results_ivalues.toTensor());
21+
22+
mod.to(torch::kF16);
23+
torch::jit::IValue jit_results_ivalues = trtorch::tests::util::RunModuleForward(mod, jit_inputs_ivalues);
24+
std::vector<at::Tensor> jit_results;
25+
jit_results.push_back(jit_results_ivalues.toTensor());
26+
27+
for (size_t i = 0; i < trt_results.size(); i++) {
28+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), 2e-5));
29+
}
30+
}
31+
32+
33+
INSTANTIATE_TEST_SUITE_P(CompiledModuleForwardIsCloseSuite,
34+
ModuleTests,
35+
testing::Values(
36+
PathAndInSize({"tests/modules/resnet18.jit.pt",
37+
{{1,3,224,224}}}),
38+
PathAndInSize({"tests/modules/resnet50.jit.pt",
39+
{{1,3,224,224}}}),
40+
PathAndInSize({"tests/modules/mobilenet_v2.jit.pt",
41+
{{1,3,224,224}}})));

0 commit comments

Comments
 (0)