diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index 23193df7e4..aea53ff6b1 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -190,6 +190,8 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
   }
 }
 
+void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n);
+
 void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_list, c10::ArrayRef<const torch::jit::Value*> out_list, int64_t in_offset, int64_t out_offset) {
   std::vector<std::pair<const torch::jit::Value*, const torch::jit::Value*>> input_output_pairs;
   std::transform(in_list.begin() + in_offset, in_list.end(), out_list.begin() + out_offset,
@@ -204,6 +206,31 @@ void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_l
   }
 }
 
+void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
+  auto condition = ctx->evaluated_value_map[n->input(0)].toBool();
+  LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Evaluating block " << (int) condition);
+  auto b = condition ? n->blocks()[0] : n->blocks()[1];
+
+  for (const auto bn : b->nodes()) {
+    if (bn->kind() == torch::jit::prim::Loop) {
+      EvaluateLoopBlock(ctx, bn);
+    } else if (bn->kind() == torch::jit::prim::If) {
+      EvaluateConditionalBlock(ctx, bn);
+    } else {
+      TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile conditionals that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.");
+      auto eval = EvaluateNode(ctx, bn);
+      if (!eval.value().isTensor()) {
+        LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << eval.value());
+      } else {
+        LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
+      }
+      ctx->AssociateValueAndIValue(bn->output(0), eval.value());
+    }
+  }
+
+  MapIValues(ctx, b->outputs(), n->outputs(), 0, 0);
+}
+
 // TODO: With functionalization pass we may be able to make this into a regular evaluator later
 void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
   auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
@@ -213,16 +240,21 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
 
   MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);
 
-  LOG_DEBUG("(Loop Evaluation) Evaluating loop " << *n);
-  LOG_DEBUG("(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
-  LOG_DEBUG("(Loop Evaluation) Start Condition: " << start_cond.toBool());
-  LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
+  LOG_DEBUG(ctx->logger, "(Loop Evaluation) Evaluating loop " << *n);
+  LOG_DEBUG(ctx->logger, "(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
+  LOG_DEBUG(ctx->logger, "(Loop Evaluation) Start Condition: " << start_cond.toBool());
+  LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
 
   while (start_cond.toBool() && trip_count.toInt() < max_trip_count.toInt()) {
     MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1);
     for (auto bn : n->blocks()[0]->nodes()) {
-      auto eval = EvaluateNode(ctx, bn);
-      if (eval) {
+      if (bn->kind() == torch::jit::prim::Loop) {
+        // recurse on the nested loop node (bn), not the enclosing loop (n)
+        EvaluateLoopBlock(ctx, bn);
+      } else if (bn->kind() == torch::jit::prim::If) {
+        EvaluateConditionalBlock(ctx, bn);
+      } else {
+        TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile loops that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.");
+        auto eval = EvaluateNode(ctx, bn);
         if (!eval.value().isTensor()) {
           LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be: " << eval.value());
         } else {
@@ -236,8 +268,8 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
     start_cond = ctx->evaluated_value_map[n->blocks()[0]->outputs()[0]];
     auto new_trip_count = torch::jit::IValue(trip_count.toInt() + 1);
     trip_count.swap(new_trip_count);
-    LOG_DEBUG("(Loop Evaluation) Condition: " << start_cond.toBool());
-    LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
+    LOG_DEBUG(ctx->logger, "(Loop Evaluation) Condition: " << start_cond.toBool());
+    LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
   }
 }
 
@@ -255,6 +287,8 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
     bool blacklisted = isNodeConversionBlacklisted(n);
     if (n->kind() == torch::jit::prim::Loop) {
       EvaluateLoopBlock(ctx, n);
+    } else if (n->kind() == torch::jit::prim::If) {
+      EvaluateConditionalBlock(ctx, n);
     } else if (to_eval) {
       auto eval = EvaluateNode(ctx, n);
       if (eval) {
@@ -303,10 +337,10 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
 std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b ) {
   std::set<std::string> unsupported_ops;
   for (const auto n : b->nodes()) {
-    if (n->kind() != torch::jit::prim::Loop && !OpSupported(n)) {
+    if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) {
       auto schema = n->maybeSchema();
       TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
-                            << " (conversion.VerifyCoverterSupportForBlock");
+                            << " (conversion.VerifyConverterSupportForBlock)");
       std::stringstream ss;
       ss << *schema;
       unsupported_ops.insert(ss.str());
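EvaluateConditionalBlock and EvaluateLoopBlock above now use the same per-node dispatch for nested control flow. A minimal sketch of that shared pattern as a standalone helper (a hypothetical refactor, not part of this patch; the TRTORCH_CHECK guards and logging are elided):

    void EvaluateBlockNodes(ConversionCtx* ctx, const torch::jit::Block* b) {
      for (const auto bn : b->nodes()) {
        if (bn->kind() == torch::jit::prim::Loop) {
          EvaluateLoopBlock(ctx, bn);          // recurse into nested loops
        } else if (bn->kind() == torch::jit::prim::If) {
          EvaluateConditionalBlock(ctx, bn);   // recurse into nested conditionals
        } else {
          // leaf nodes must be evaluatable at conversion time
          auto eval = EvaluateNode(ctx, bn);
          ctx->AssociateValueAndIValue(bn->output(0), eval.value());
        }
      }
    }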
be: " << eval.value()); } else { @@ -236,8 +268,8 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) { start_cond = ctx->evaluated_value_map[n->blocks()[0]->outputs()[0]]; auto new_trip_count = torch::jit::IValue(trip_count.toInt() + 1); trip_count.swap(new_trip_count); - LOG_DEBUG("(Loop Evaluation) Condition: " << start_cond.toBool()); - LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt()); + LOG_DEBUG(ctx->logger, "(Loop Evaluation) Condition: " << start_cond.toBool()); + LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt()); } } @@ -255,6 +287,8 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver bool blacklisted = isNodeConversionBlacklisted(n); if (n->kind() == torch::jit::prim::Loop) { EvaluateLoopBlock(ctx, n); + } else if (n->kind() == torch::jit::prim::If) { + EvaluateConditionalBlock(ctx, n); } else if (to_eval) { auto eval = EvaluateNode(ctx, n); if (eval) { @@ -303,10 +337,10 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil std::set GetUnsupportedOpsInBlock(const torch::jit::Block* b ) { std::set unsupported_ops; for (const auto n : b->nodes()) { - if (n->kind() != torch::jit::prim::Loop && !OpSupported(n)) { + if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) { auto schema = n->maybeSchema(); TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \ - << " (conversion.VerifyCoverterSupportForBlock"); + << " (conversion.VerifyCoverterSupportForBlock)"); std::stringstream ss; ss << *schema; unsupported_ops.insert(ss.str()); diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index 7037e9a512..2993ee593e 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -13,7 +13,7 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) { << "\n Operating Precision: " << s.op_precision \ << "\n Make Refittable Engine: " << s.refit \ << "\n Debuggable Engine: " << s.debug \ - << "\n Strict Type: " << s.strict_types \ + << "\n Strict Types: " << s.strict_types \ << "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \ << "\n Min Timing Iterations: " << s.num_min_timing_iters \ << "\n Avg Timing Iterations: " << s.num_avg_timing_iters \ @@ -51,6 +51,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) case nvinfer1::DataType::kINT8: TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does support INT8"); cfg->setFlag(nvinfer1::BuilderFlag::kINT8); + if (!settings.strict_types) { + cfg->setFlag(nvinfer1::BuilderFlag::kFP16); + } input_type = nvinfer1::DataType::kFLOAT; TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator"); cfg->setInt8Calibrator(settings.calibrator); diff --git a/core/conversion/converters/NodeConverterRegistry.cpp b/core/conversion/converters/NodeConverterRegistry.cpp index 7315b0203d..7d9e30c47d 100644 --- a/core/conversion/converters/NodeConverterRegistry.cpp +++ b/core/conversion/converters/NodeConverterRegistry.cpp @@ -48,6 +48,10 @@ class NodeConverterRegistry { bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) { LOG_DEBUG("Registering converter for " << canonical_schema_string(*signature)); auto name = 
diff --git a/core/conversion/evaluators/BUILD b/core/conversion/evaluators/BUILD
index 38bb7bb0d3..cbc5318342 100644
--- a/core/conversion/evaluators/BUILD
+++ b/core/conversion/evaluators/BUILD
@@ -15,7 +15,8 @@ cc_library(
     srcs = [
        "NodeEvaluatorRegistry.cpp",
        "prim.cpp",
-       "aten.cpp"
+       "aten.cpp",
+       "eval_macros.h"
     ],
     deps = [
        "//core/util:prelude",
diff --git a/core/conversion/evaluators/NodeEvaluatorRegistry.cpp b/core/conversion/evaluators/NodeEvaluatorRegistry.cpp
index b264bf647b..ac1673436a 100644
--- a/core/conversion/evaluators/NodeEvaluatorRegistry.cpp
+++ b/core/conversion/evaluators/NodeEvaluatorRegistry.cpp
@@ -30,6 +30,10 @@ class NodeEvaluatorRegistry {
 public:
   void RegisterEvaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
     LOG_DEBUG("Registering evaluator for " << node_kind.toQualString());
+    auto iter = evaluator_lut_.find(node_kind);
+    if (iter != evaluator_lut_.end()) {
+      TRTORCH_THROW_ERROR("Attempting to override already registered evaluator " << node_kind.toQualString() << ", merge implementations instead");
+    }
     evaluator_lut_[node_kind] = std::move(eval_reg);
   }
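Evaluators get the stricter treatment: a duplicate registration throws (during static initialization, since registrations run at load time) rather than warning. A sketch of a duplicate that would now trip the error, using the registration API that appears in aten.cpp below (the second registration and its body are hypothetical):

    // prim::Constant already has an evaluator, so this second registration
    // raises "Attempting to override already registered evaluator prim::Constant,
    // merge implementations instead".
    auto dup = RegisterNodeEvaluators().evaluator({
      c10::Symbol::fromQualString("prim::Constant"),
      [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
        return {};
      }
    });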
"aten::le.int(int a, int b) -> (bool)", + "aten::le.float(float a, float b) -> (bool)", + "aten::le.int_float(int a, float b) -> (bool)", + "aten::le.float_int(float a, int b) -> (bool)", + }) +); + +DEFINE_GENERIC_TWO_INPUT_EVALUATOR( + ge, + "aten::ge", + a >= b, + std::set({ + "aten::ge.bool(bool a, bool b) -> (bool)", + "aten::ge.int(int a, int b) -> (bool)", + "aten::ge.float(float a, float b) -> (bool)", + "aten::ge.int_float(int a, float b) -> (bool)", + "aten::ge.float_int(float a, int b) -> (bool)", + }) +); + +DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(and, "aten::__and__", a && b, bool, {"aten::__and__(int a, int b) -> (bool)"}); +DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(or, "aten::__or__", a || b, bool, {"aten::__or__(int a, int b) -> (bool)"}); +DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(xor, "aten::__xor__", a != b, bool, {"aten::__xor__(int a, int b) -> (bool)"}); +DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(int_div, "aten::__round_to_zero_floordiv", a / b, int64_t, {"aten::__round_to_zero_floordiv(int a, int b) -> (int)"}); + +auto aten_registrations TRTORCH_UNUSED = RegisterNodeEvaluators() .evaluator({ c10::Symbol::fromQualString("aten::zeros"), // aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor) @@ -37,38 +120,6 @@ auto aten_registrations = RegisterNodeEvaluators() auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options); return out_tensor; } - }).evaluator({ - c10::Symbol::fromQualString("aten::add"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - auto a = args.at(n->input(0)).unwrapToInt(); - auto b = args.at(n->input(1)).unwrapToInt(); - return a + b; - }, - EvalOptions().validSchemas({"aten::add.int(int a, int b) -> (int)"}) - }).evaluator({ - c10::Symbol::fromQualString("aten::mul"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - auto a = args.at(n->input(0)).unwrapToInt(); - auto b = args.at(n->input(1)).unwrapToInt(); - return a * b; - }, - EvalOptions().validSchemas({"aten::mul.int(int a, int b) -> (int)"}) - }).evaluator({ - c10::Symbol::fromQualString("aten::sub"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - auto a = args.at(n->input(0)).unwrapToInt(); - auto b = args.at(n->input(1)).unwrapToInt(); - return a - b; - }, - EvalOptions().validSchemas({"aten::sub.int(int a, int b) -> (int)"}) - }).evaluator({ - c10::Symbol::fromQualString("aten::__round_to_zero_floordiv"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - auto a = args.at(n->input(0)).unwrapToInt(); - auto b = args.at(n->input(1)).unwrapToInt(); - return a / b; - }, - EvalOptions().validSchemas({"aten::__round_to_zero_floordiv(int a, int b) -> (int)"}) }).evaluator({ c10::Symbol::fromQualString("aten::slice"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { @@ -139,7 +190,7 @@ auto aten_registrations = RegisterNodeEvaluators() }).evaluator({ c10::Symbol::fromQualString("aten::__getitem__"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - auto list = args.at(n->input(0)).unwrapToIntList(); + auto list = args.at(n->input(0)).IValue()->to>(); auto idx = args.at(n->input(1)).unwrapToInt(); const int64_t list_size = list.size(); @@ -153,8 +204,8 @@ auto aten_registrations = RegisterNodeEvaluators() }).evaluator({ c10::Symbol::fromQualString("aten::append"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - auto list = args.at(n->input(0)).unwrapToIntList(); - auto el = args.at(n->input(1)).unwrapToInt(); + auto list 
@@ -153,8 +204,8 @@ auto aten_registrations = RegisterNodeEvaluators()
   }).evaluator({
     c10::Symbol::fromQualString("aten::append"),
     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-      auto list = args.at(n->input(0)).unwrapToIntList();
-      auto el = args.at(n->input(1)).unwrapToInt();
+      auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
+      auto el = *(args.at(n->input(1)).IValue());  // copy the IValue itself, not the pointer
 
       list.push_back(std::move(el));
       return list;
@@ -172,6 +223,226 @@ auto aten_registrations = RegisterNodeEvaluators()
     EvalOptions().validSchemas({
       "aten::neg.int(int a) -> (int)",
     })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::add"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      if (args.at(n->input(0)).IValue()->isInt()) {
+        auto a = args.at(n->input(0)).unwrapToInt();
+        auto b = args.at(n->input(1)).unwrapToInt();
+        return a + b;
+      } else if (args.at(n->input(0)).IValue()->isDouble()) {
+        auto a = args.at(n->input(0)).unwrapToDouble();
+        auto b = args.at(n->input(1)).unwrapToDouble();
+        return a + b;
+      } else {
+        TRTORCH_THROW_ERROR("Unimplemented data type for aten::add evaluator: " << args.at(n->input(0)).IValue()->type()->str());
+        return {};
+      }
+    },
+    EvalOptions().validSchemas({
+      "aten::add.int(int a, int b) -> (int)",
+      "aten::add.float(float a, float b) -> (float)"
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::mul"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      if (args.at(n->input(0)).IValue()->isInt()) {
+        auto a = args.at(n->input(0)).unwrapToInt();
+        auto b = args.at(n->input(1)).unwrapToInt();
+        return a * b;
+      } else if (args.at(n->input(0)).IValue()->isDouble()) {
+        auto a = args.at(n->input(0)).unwrapToDouble();
+        auto b = args.at(n->input(1)).unwrapToDouble();
+        return a * b;
+      } else {
+        TRTORCH_THROW_ERROR("Unimplemented data type for aten::mul evaluator: " << args.at(n->input(0)).IValue()->type()->str());
+        return {};
+      }
+    },
+    EvalOptions().validSchemas({
+      "aten::mul.int(int a, int b) -> (int)",
+      "aten::mul.float(float a, float b) -> (float)"
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::sub"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      if (args.at(n->input(0)).IValue()->isInt()) {
+        auto a = args.at(n->input(0)).unwrapToInt();
+        auto b = args.at(n->input(1)).unwrapToInt();
+        return a - b;
+      } else if (args.at(n->input(0)).IValue()->isDouble()) {
+        auto a = args.at(n->input(0)).unwrapToDouble();
+        auto b = args.at(n->input(1)).unwrapToDouble();
+        return a - b;
+      } else {
+        TRTORCH_THROW_ERROR("Unimplemented data type for aten::sub evaluator: " << args.at(n->input(0)).IValue()->type()->str());
+        return {};
+      }
+    },
+    EvalOptions().validSchemas({
+      "aten::sub.float(float a, float b) -> (float)",
+      "aten::sub.int(int a, int b) -> (int)",
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::Bool"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      if (args.at(n->input(0)).IValue()->isInt()) {
+        auto a = args.at(n->input(0)).unwrapToInt();
+        return (bool) a;
+      } else if (args.at(n->input(0)).IValue()->isDouble()) {
+        auto a = args.at(n->input(0)).unwrapToDouble();
+        return (bool) a;
+      } else {
+        TRTORCH_THROW_ERROR("Unimplemented data type for aten::Bool evaluator: " << args.at(n->input(0)).IValue()->type()->str());
+        return {};
+      }
+    },
+    EvalOptions().validSchemas({
+      "aten::Bool.int(int a) -> (bool)",
+      "aten::Bool.float(float b) -> (bool)"
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::Float"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      if (args.at(n->input(0)).IValue()->isInt()) {
+        auto a = args.at(n->input(0)).unwrapToInt();
+        return (float) a;
+      } else if (args.at(n->input(0)).IValue()->isDouble()) {
+        auto a = args.at(n->input(0)).unwrapToDouble();
+        return (float) a;
+      } else if
(args.at(n->input(0)).IValue()->isBool()) {
+        auto a = args.at(n->input(0)).unwrapToBool();
+        return (float) a;
+      } else {
+        TRTORCH_THROW_ERROR("Unimplemented data type for aten::Float evaluator: " << args.at(n->input(0)).IValue()->type()->str());
+        return {};
+      }
+    },
+    EvalOptions().validSchemas({
+      "aten::Float.Scalar(Scalar a) -> float",
+      "aten::Float.int(int a) -> float",
+      "aten::Float.bool(bool a) -> float",
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::__not__"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      auto el = args.at(n->input(0)).unwrapToBool();
+
+      return !el;
+    },
+    EvalOptions().validSchemas({
+      "aten::__not__(bool self) -> bool",
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::__is__"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      auto self = args.at(n->input(0)).IValue();
+      auto obj = args.at(n->input(1)).IValue();
+
+      return self->isSameIdentity(*obj);
+    },
+    EvalOptions().validSchemas({
+      "aten::__is__(t1 self, t2 obj) -> bool",
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::__isnot__"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      auto self = args.at(n->input(0)).IValue();
+      auto obj = args.at(n->input(1)).IValue();
+
+      return !self->isSameIdentity(*obj);
+    },
+    EvalOptions().validSchemas({
+      "aten::__isnot__(t1 self, t2 obj) -> bool",
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::numel"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      LOG_WARNING("There may be undefined behavior using dynamic shape and aten::numel");
+      auto tensor_var = args.at(n->input(0));
+      if (tensor_var.isITensor()) {
+        auto tensor = tensor_var.ITensor();
+        return util::volume(tensor->getDimensions());
+      } else {
+        auto tensor = tensor_var.unwrapToTensor();
+        return tensor.numel();
+      }
+    },
+    EvalOptions().validSchemas({
+      "aten::numel(Tensor self) -> int",
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::dim"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      auto tensor_var = args.at(n->input(0));
+      if (tensor_var.isITensor()) {
+        auto tensor = tensor_var.ITensor();
+        return tensor->getDimensions().nbDims;
+      } else {
+        auto tensor = tensor_var.unwrapToTensor();
+        return tensor.dim();
+      }
+    },
+    EvalOptions().validSchemas({
+      "aten::dim(Tensor self) -> int",
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::div"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      if (args.at(n->input(0)).IValue()->isInt()) {
+        auto a = args.at(n->input(0)).unwrapToInt();
+        auto b = args.at(n->input(1)).unwrapToInt();
+        return static_cast<double>(a) / static_cast<double>(b);
+      } else if (args.at(n->input(0)).IValue()->isDouble()) {
+        auto a = args.at(n->input(0)).unwrapToDouble();
+        auto b = args.at(n->input(1)).unwrapToDouble();
+        return a / b;
+      } else {
+        TRTORCH_THROW_ERROR("Unimplemented data type for aten::div evaluator: " << args.at(n->input(0)).IValue()->type()->str());
+        return {};
+      }
+    },
+    EvalOptions().validSchemas({
+      "aten::div.Scalar(Scalar a, Scalar b) -> (float)",
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::floordiv"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      if (args.at(n->input(0)).IValue()->isInt()) {
+        auto a = args.at(n->input(0)).unwrapToInt();
+        auto b = args.at(n->input(1)).unwrapToInt();
+        // floor the true quotient; plain integer a / b would truncate toward zero
+        return static_cast<int64_t>(std::floor(static_cast<double>(a) / static_cast<double>(b)));
+      } else if (args.at(n->input(0)).IValue()->isDouble()) {
+        auto a = args.at(n->input(0)).unwrapToDouble();
+        auto b = args.at(n->input(1)).unwrapToDouble();
+        return std::floor(a / b);
+      } else {
+        TRTORCH_THROW_ERROR("Unimplemented data type for aten::floordiv evaluator: " << args.at(n->input(0)).IValue()->type()->str());
+        return {};
+      }
+    },
+    EvalOptions().validSchemas({
+      "aten::floordiv.float(float a, float b) -> (int)",
+      "aten::floordiv.int(int a, int b) -> (int)",
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::floor"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      auto el = args.at(n->input(0)).unwrapToDouble();
+
+      return static_cast<int64_t>(std::floor(el));
+    },
+    EvalOptions().validSchemas({
+      "aten::floor.float(float a) -> (int)",
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("aten::warn"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      auto warning = args.at(n->input(0)).IValue()->toStringRef();
+      LOG_WARNING(warning);
+      return {};
+    },
+    EvalOptions()
   });
 }
 } // namespace evaluators
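A worked example of the floor-division behavior in the aten::floordiv evaluator above (values hypothetical): C++ integer division truncates toward zero, so the quotient has to be floored through a double to match Python's semantics for negative operands:

    int64_t a = -7, b = 2;
    auto truncated = a / b;                                 // -3 (rounds toward zero)
    auto floored = static_cast<int64_t>(std::floor(
        static_cast<double>(a) / static_cast<double>(b)));  // -4, matches -7 // 2 in Python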
node_kind << " evaluator a arg: " \ + << args.at(n->input(0)).IValue()->type()->str()); \ + return {}; \ + } \ + }, \ + EvalOptions().validSchemas(schemas) \ + }); + +#define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \ + auto node_kind##_registrations TRTORCH_UNUSED = \ + RegisterNodeEvaluators().evaluator({ \ + c10::Symbol::fromQualString(node_name), \ + [](const torch::jit::Node* n, kwargs& args) -> c10::optional { \ + auto a = args.at(n->input(0)).unwrapTo(); \ + auto b = args.at(n->input(1)).unwrapTo(); \ + return operation; \ + }, \ + EvalOptions().validSchemas(schemas) \ + }); diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp index dddd4c304a..f89dfa02d7 100644 --- a/core/conversion/evaluators/prim.cpp +++ b/core/conversion/evaluators/prim.cpp @@ -10,6 +10,7 @@ #include "torch/torch.h" #include "core/conversion/evaluators/evaluators.h" +#include "core/conversion/evaluators/eval_macros.h" namespace trtorch { namespace core { @@ -97,18 +98,121 @@ auto prim_registrations = RegisterNodeEvaluators() }).evaluator({ c10::Symbol::fromQualString("prim::min"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - auto a = args.at(n->input(0)).unwrapToIntList(); - int64_t min = std::numeric_limits::max(); + if (n->inputs().size() == 1) { + auto a = args.at(n->input(0)).unwrapToIntList(); + int64_t min = std::numeric_limits::max(); - for (size_t i = 0; i < a.size(); i++) { - if (a[i] < min) { - min = a[i]; + for (size_t i = 0; i < a.size(); i++) { + if (a[i] < min) { + min = a[i]; + } } + + return min; + } else if (n->inputs().size() == 2) { + if (args.at(n->input(0)).IValue()->isInt()) { + auto a = args.at(n->input(0)).unwrapToInt(); + if (args.at(n->input(1)).IValue()->isInt()) { + auto b = args.at(n->input(1)).unwrapToInt(); + return a < b ? a : b; + } else if (args.at(n->input(1)).IValue()->isDouble()) { + auto b = args.at(n->input(1)).unwrapToDouble(); + return a < b ? a : b; + } else { + TRTORCH_THROW_ERROR("Unimplemented data type for " << n->kind().toQualString() << " evaluator b arg: " + << args.at(n->input(1)).IValue()->type()->str()); + return {}; + } + } else if (args.at(n->input(0)).IValue()->isDouble()) { + auto a = args.at(n->input(0)).unwrapToDouble(); + if (args.at(n->input(1)).IValue()->isInt()) { + auto b = args.at(n->input(1)).unwrapToInt(); + return a < b ? a : b; + } else if (args.at(n->input(1)).IValue()->isDouble()) { + auto b = args.at(n->input(1)).unwrapToDouble(); + return a < b ? 
diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp
index dddd4c304a..f89dfa02d7 100644
--- a/core/conversion/evaluators/prim.cpp
+++ b/core/conversion/evaluators/prim.cpp
@@ -10,6 +10,7 @@
 #include "torch/torch.h"
 
 #include "core/conversion/evaluators/evaluators.h"
+#include "core/conversion/evaluators/eval_macros.h"
 
 namespace trtorch {
 namespace core {
@@ -97,18 +98,121 @@ auto prim_registrations = RegisterNodeEvaluators()
   }).evaluator({
     c10::Symbol::fromQualString("prim::min"),
     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-      auto a = args.at(n->input(0)).unwrapToIntList();
-      int64_t min = std::numeric_limits<int64_t>::max();
+      if (n->inputs().size() == 1) {
+        auto a = args.at(n->input(0)).unwrapToIntList();
+        int64_t min = std::numeric_limits<int64_t>::max();
 
-      for (size_t i = 0; i < a.size(); i++) {
-        if (a[i] < min) {
-          min = a[i];
+        for (size_t i = 0; i < a.size(); i++) {
+          if (a[i] < min) {
+            min = a[i];
+          }
         }
-      }
 
-      return min;
+        return min;
+      } else if (n->inputs().size() == 2) {
+        if (args.at(n->input(0)).IValue()->isInt()) {
+          auto a = args.at(n->input(0)).unwrapToInt();
+          if (args.at(n->input(1)).IValue()->isInt()) {
+            auto b = args.at(n->input(1)).unwrapToInt();
+            return a < b ? a : b;
+          } else if (args.at(n->input(1)).IValue()->isDouble()) {
+            auto b = args.at(n->input(1)).unwrapToDouble();
+            return a < b ? a : b;
+          } else {
+            TRTORCH_THROW_ERROR("Unimplemented data type for " << n->kind().toQualString() << " evaluator b arg: "
+                                << args.at(n->input(1)).IValue()->type()->str());
+            return {};
+          }
+        } else if (args.at(n->input(0)).IValue()->isDouble()) {
+          auto a = args.at(n->input(0)).unwrapToDouble();
+          if (args.at(n->input(1)).IValue()->isInt()) {
+            auto b = args.at(n->input(1)).unwrapToInt();
+            return a < b ? a : b;
+          } else if (args.at(n->input(1)).IValue()->isDouble()) {
+            auto b = args.at(n->input(1)).unwrapToDouble();
+            return a < b ? a : b;
+          } else {
+            TRTORCH_THROW_ERROR("Unimplemented data type for " << n->kind().toQualString() << " evaluator b arg: "
+                                << args.at(n->input(1)).IValue()->type()->str());
+            return {};
+          }
+        } else {
+          TRTORCH_THROW_ERROR("Unimplemented data type for " << n->kind().toQualString() << " evaluator a arg: "
+                              << args.at(n->input(0)).IValue()->type()->str());
+          return {};
+        }
+      } else {
+        TRTORCH_THROW_ERROR("Unimplemented " << n->kind().toQualString() << " evaluator case");
+        return {};
+      }
     },
-    EvalOptions().validSchemas({"prim::min.self_int(int[] self) -> (int)"})
+    EvalOptions().validSchemas({
+      "prim::min.self_int(int[] self) -> (int)",
+      "prim::min.bool(bool a, bool b) -> (bool)",
+      "prim::min.int(int a, int b) -> (int)",
+      "prim::min.float(float a, float b) -> (float)",
+      "prim::min.int_float(int a, float b) -> (float)",
+      "prim::min.float_int(float a, int b) -> (float)",
+    })
+  }).evaluator({
+    c10::Symbol::fromQualString("prim::max"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      if (n->inputs().size() == 1) {
+        auto a = args.at(n->input(0)).unwrapToIntList();
+        int64_t max = std::numeric_limits<int64_t>::min();
+
+        for (size_t i = 0; i < a.size(); i++) {
+          if (a[i] > max) {
+            max = a[i];
+          }
+        }
+
+        return max;
+      } else if (n->inputs().size() == 2) {
+        if (args.at(n->input(0)).IValue()->isInt()) {
+          auto a = args.at(n->input(0)).unwrapToInt();
+          if (args.at(n->input(1)).IValue()->isInt()) {
+            auto b = args.at(n->input(1)).unwrapToInt();
+            return a > b ? a : b;
+          } else if (args.at(n->input(1)).IValue()->isDouble()) {
+            auto b = args.at(n->input(1)).unwrapToDouble();
+            return a > b ? a : b;
+          } else {
+            TRTORCH_THROW_ERROR("Unimplemented data type for " << n->kind().toQualString() << " evaluator b arg: "
+                                << args.at(n->input(1)).IValue()->type()->str());
+            return {};
+          }
+        } else if (args.at(n->input(0)).IValue()->isDouble()) {
+          auto a = args.at(n->input(0)).unwrapToDouble();
+          if (args.at(n->input(1)).IValue()->isInt()) {
+            auto b = args.at(n->input(1)).unwrapToInt();
+            return a > b ? a : b;
+          } else if (args.at(n->input(1)).IValue()->isDouble()) {
+            auto b = args.at(n->input(1)).unwrapToDouble();
+            return a > b ? a : b;
+          } else {
+            TRTORCH_THROW_ERROR("Unimplemented data type for " << n->kind().toQualString() << " evaluator b arg: "
+                                << args.at(n->input(1)).IValue()->type()->str());
+            return {};
+          }
+        } else {
+          TRTORCH_THROW_ERROR("Unimplemented data type for " << n->kind().toQualString() << " evaluator a arg: "
+                              << args.at(n->input(0)).IValue()->type()->str());
+          return {};
+        }
+      } else {
+        TRTORCH_THROW_ERROR("Unimplemented " << n->kind().toQualString() << " evaluator case");
+        return {};
+      }
+    },
+    EvalOptions().validSchemas({
+      "prim::max.self_int(int[] self) -> (int)",
+      "prim::max.bool(bool a, bool b) -> (bool)",
+      "prim::max.int(int a, int b) -> (int)",
+      "prim::max.float(float a, float b) -> (float)",
+      "prim::max.int_float(int a, float b) -> (float)",
+      "prim::max.float_int(float a, int b) -> (float)",
+    })
   }).evaluator({
     c10::Symbol::fromQualString("prim::shape"),
     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
@@ -125,6 +229,23 @@ auto prim_registrations = RegisterNodeEvaluators()
     EvalOptions().validSchemas({
       "prim::shape(Tensor a) -> (int[])"
     })
+  }).evaluator({
+    c10::Symbol::fromQualString("prim::unchecked_cast"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      return *(args.at(n->input(0)).IValue());
+    }
+  }).evaluator({
+    c10::Symbol::fromQualString("prim::Uninitialized"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      return c10::IValue::uninitialized();
+    }
+  }).evaluator({
+    c10::Symbol::fromQualString("prim::RaiseException"),
+    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+      auto exception = args.at(n->input(0)).IValue()->toStringRef();
+      TRTORCH_THROW_ERROR(exception);
+      return {};
+    }
   });
 }
 } // namespace evaluators
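Taken together, these changes let TRTorch compile TorchScript modules whose loops and conditionals resolve entirely at conversion time. An end-to-end sketch, assuming the trtorch C++ API of this release (module path and input shape are hypothetical):

    #include "torch/script.h"
    #include "trtorch/trtorch.h"

    int main() {
      // a scripted module containing e.g. a for loop with a static trip count
      auto mod = torch::jit::load("module_with_static_control_flow.ts");  // hypothetical path
      auto settings = trtorch::ExtraInfo({{1, 3, 224, 224}});
      // prim::Loop / prim::If nodes are evaluated during conversion,
      // so only the tensor ops end up in the TensorRT engine
      auto trt_mod = trtorch::CompileGraph(mod, settings);
      return 0;
    }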