@@ -75,4 +75,107 @@ TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
75
75
auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {in});
76
76
77
77
ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
78
+ }
79
+
80
+ TEST (Evaluators, ATenArangeIntEvaluatesCorrectly) {
81
+ const auto graph = R"IR(
82
+ graph():
83
+ %0 : int = prim::Constant[value=51]()
84
+ %1 : None = prim::Constant()
85
+ %2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
86
+ return (%2))IR" ;
87
+
88
+ auto g = std::make_shared<torch::jit::Graph>();
89
+ torch::jit::parseIR (graph, &*g);
90
+
91
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
92
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
93
+
94
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
95
+ }
96
+
97
+ TEST (Evaluators, ATenArangeFloatEvaluatesCorrectly) {
98
+ const auto graph = R"IR(
99
+ graph():
100
+ %0 : float = prim::Constant[value=51.2]()
101
+ %1 : None = prim::Constant()
102
+ %2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
103
+ return (%2))IR" ;
104
+
105
+ auto g = std::make_shared<torch::jit::Graph>();
106
+ torch::jit::parseIR (graph, &*g);
107
+
108
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
109
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
110
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
111
+ }
112
+
113
+ TEST (Evaluators, ATenArangeStartEndIntEvaluatesCorrectly) {
114
+ const auto graph = R"IR(
115
+ graph():
116
+ %0 : int = prim::Constant[value=1]()
117
+ %1 : int = prim::Constant[value=51]()
118
+ %2 : None = prim::Constant()
119
+ %3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
120
+ return (%3))IR" ;
121
+
122
+ auto g = std::make_shared<torch::jit::Graph>();
123
+ torch::jit::parseIR (graph, &*g);
124
+
125
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
126
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
127
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
128
+ }
129
+
130
+ TEST (Evaluators, ATenArangeStartEndFloatEvaluatesCorrectly) {
131
+ const auto graph = R"IR(
132
+ graph():
133
+ %0 : float = prim::Constant[value=1.5]()
134
+ %1 : float = prim::Constant[value=51.2]()
135
+ %2 : None = prim::Constant()
136
+ %3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
137
+ return (%3))IR" ;
138
+
139
+ auto g = std::make_shared<torch::jit::Graph>();
140
+ torch::jit::parseIR (graph, &*g);
141
+
142
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
143
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
144
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
145
+ }
146
+
147
+ TEST (Evaluators, ATenArangeStartEndStepIntEvaluatesCorrectly) {
148
+ const auto graph = R"IR(
149
+ graph():
150
+ %0 : int = prim::Constant[value=1]()
151
+ %1 : int = prim::Constant[value=51]()
152
+ %2 : int = prim::Constant[value=1]()
153
+ %3 : None = prim::Constant()
154
+ %4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
155
+ return (%4))IR" ;
156
+
157
+ auto g = std::make_shared<torch::jit::Graph>();
158
+ torch::jit::parseIR (graph, &*g);
159
+
160
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
161
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
162
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
163
+ }
164
+
165
+ TEST (Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) {
166
+ const auto graph = R"IR(
167
+ graph():
168
+ %0 : float = prim::Constant[value=1.2]()
169
+ %1 : float = prim::Constant[value=51.6]()
170
+ %2 : float = prim::Constant[value=1.5]()
171
+ %3 : None = prim::Constant()
172
+ %4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
173
+ return (%4))IR" ;
174
+
175
+ auto g = std::make_shared<torch::jit::Graph>();
176
+ torch::jit::parseIR (graph, &*g);
177
+
178
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
179
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
180
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
78
181
}
0 commit comments