Skip to content

Commit dcb1474

Browse files
committed
feat(//core/conversion): Adds the ability to evaluate loops
Note: This is unguarded currently; loops either must be able to be evaluated at compile time or the module is not supported. Upcoming work on RNNs will add support for more types of loops. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent d351717 commit dcb1474

File tree

1 file changed

+79
-7
lines changed

1 file changed

+79
-7
lines changed

core/conversion/conversion.cpp

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,57 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
190190
}
191191
}
192192

193+
// Copies already-evaluated IValues between two lists of JIT values, pairing
// in_list[in_offset + i] with out_list[out_offset + i] for every remaining
// element of in_list. Used by EvaluateLoopBlock to shuttle loop-carried
// values between a prim::Loop node and its body block.
//
// Precondition: out_list must have at least (in_list.size() - in_offset)
// entries after out_offset. NOTE(review): this is not checked here; the
// callers in EvaluateLoopBlock satisfy it by construction — verify for any
// new caller.
void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_list, c10::ArrayRef<const torch::jit::Value*> out_list, int64_t in_offset, int64_t out_offset) {
  // Iterate the pairs directly instead of first materializing a temporary
  // std::vector of pairs via std::transform + back_inserter — same visiting
  // order, no intermediate allocation or element copies.
  for (int64_t i = 0; in_offset + i < static_cast<int64_t>(in_list.size()); ++i) {
    auto in = in_list[in_offset + i];
    auto out = out_list[out_offset + i];
    ctx->evaluated_value_map[out] = torch::jit::IValue(ctx->evaluated_value_map[in]);
  }
}
206+
207+
// TODO: With functionalization pass we may be able to make this into a regular evaluator later
208+
void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
209+
auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
210+
auto start_cond = ctx->evaluated_value_map[n->input(1)];
211+
ctx->evaluated_value_map[n->blocks()[0]->inputs()[0]] = torch::jit::IValue(0);
212+
auto trip_count = ctx->evaluated_value_map[n->blocks()[0]->inputs()[0]];
213+
214+
MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);
215+
216+
LOG_DEBUG("(Loop Evaluation) Evaluating loop " << *n);
217+
LOG_DEBUG("(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
218+
LOG_DEBUG("(Loop Evaluation) Start Condition: " << start_cond.toBool());
219+
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
220+
221+
while (start_cond.toBool() && trip_count.toInt() < max_trip_count.toInt()) {
222+
MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1);
223+
for (auto bn : n->blocks()[0]->nodes()) {
224+
auto eval = EvaluateNode(ctx, bn);
225+
if (eval) {
226+
if (!eval.value().isTensor()) {
227+
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be: " << eval.value());
228+
} else {
229+
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
230+
}
231+
ctx->AssociateValueAndIValue(bn->output(0), eval.value());
232+
}
233+
}
234+
235+
MapIValues(ctx, n->blocks()[0]->outputs(), n->outputs(), 1, 0);
236+
start_cond = ctx->evaluated_value_map[n->blocks()[0]->outputs()[0]];
237+
auto new_trip_count = torch::jit::IValue(trip_count.toInt() + 1);
238+
trip_count.swap(new_trip_count);
239+
LOG_DEBUG("(Loop Evaluation) Condition: " << start_cond.toBool());
240+
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
241+
}
242+
}
243+
193244
void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
194245
LOG_INFO(ctx->logger, "Converting Block");
195246

@@ -202,7 +253,19 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
202253
for (const auto n : nodes) {
203254
bool to_eval = evaluators::shouldEvalAtConversionTime(n);
204255
bool blacklisted = isNodeConversionBlacklisted(n);
205-
if (!to_eval && !blacklisted) {
256+
if (n->kind() == torch::jit::prim::Loop) {
257+
EvaluateLoopBlock(ctx, n);
258+
} else if (to_eval) {
259+
auto eval = EvaluateNode(ctx, n);
260+
if (eval) {
261+
if (!eval.value().isTensor()) {
262+
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
263+
} else {
264+
LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
265+
}
266+
ctx->AssociateValueAndIValue(n->output(0), eval.value());
267+
}
268+
} else if (!blacklisted) {
206269
// Should error out if something fails
207270
AddLayer(ctx, n);
208271
} else {
@@ -237,22 +300,29 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
237300
return engine;
238301
}
239302

240-
bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
241-
bool supported = true;
303+
// Recursively collects the schemas of all unsupported ops in a block,
// descending into the sub-blocks of each node (e.g. prim::Loop / prim::If
// bodies). prim::Loop nodes themselves are excluded because loops are
// evaluated at conversion time by EvaluateLoopBlock rather than converted.
// Returns the (de-duplicated) set of unsupported operator schema strings.
std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
  std::set<std::string> unsupported_ops;
  for (const auto n : b->nodes()) {
    if (!OpSupported(n) && n->kind() != torch::jit::prim::Loop) {
      auto schema = n->maybeSchema();
      // Fixed message: "Coverter" typo and unbalanced parenthesis
      TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
        << " (conversion.VerifyConverterSupportForBlock)");
      std::stringstream ss;
      ss << *schema;
      unsupported_ops.insert(ss.str());
    }
    // Sub-blocks (loop / if bodies) may also contain unsupported ops
    for (const auto sub_b : n->blocks()) {
      auto sub_b_unsupported_ops = GetUnsupportedOpsInBlock(sub_b);
      unsupported_ops.insert(sub_b_unsupported_ops.begin(), sub_b_unsupported_ops.end());
    }
  }
  return unsupported_ops;
}
321+
322+
bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
323+
auto unsupported_ops = GetUnsupportedOpsInBlock(b);
254324

255-
if (!supported) {
325+
if (unsupported_ops.size() != 0) {
256326
std::stringstream unsupported_msg;
257327
unsupported_msg << "Method requested cannot be compiled by TRTorch.\nUnsupported operators listed below:" << std::endl;
258328
for (auto s : unsupported_ops) {
@@ -261,8 +331,10 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
261331
unsupported_msg << "You can either implement converters for these ops in your application or request implementation" << std::endl;
262332
unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl;
263333
LOG_ERROR(unsupported_msg.str());
334+
return false;
335+
} else {
336+
return true;
264337
}
265-
return supported;
266338
}
267339

268340
} // namespace conversion

0 commit comments

Comments
 (0)