
Commit 68c934a

fix(//core/conversion/evaluator): Custom to IValue that handles int[]
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 0e90f78 commit 68c934a
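
The gist of the change, distilled into a standalone sketch (the helper name below is mine, not part of the commit): when an int[] constant carries its value as a generic IValue attribute (AttributeKind::ival) rather than an integer-array attribute, the typed accessor Node::is() throws, so the lookup falls back to Node::ival().

#include <exception>

#include "ATen/core/ivalue.h"
#include "torch/csrc/jit/ir/ir.h"

// Illustrative helper only: read an int[] off a prim::Constant node regardless
// of which attribute kind the value was stored with.
c10::IValue int_list_from_constant(const torch::jit::Node* node) {
    try {
        return node->is(c10::attr::value);   // AttributeKind::is - plain int64_t vector
    } catch (const std::exception&) {
        return node->ival(c10::attr::value); // AttributeKind::ival - generic IValue
    }
}
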

4 files changed, +126 -3 lines changed

core/conversion/evaluators/BUILD

Lines changed: 3 additions & 1 deletion
@@ -16,7 +16,9 @@ cc_library(
         "NodeEvaluatorRegistry.cpp",
         "prim.cpp",
         "aten.cpp",
-        "eval_macros.h"
+        "eval_macros.h",
+        "eval_util.h",
+        "eval_util.cpp"
     ],
     deps = [
         "//core/util:prelude",

core/conversion/evaluators/eval_util.cpp

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+#include "ATen/core/ivalue.h"
+#include "ATen/core/List.h"
+#include "core/util/prelude.h"
+#include "ATen/core/functional.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace evaluators {
+
+// TODO: Switch back to PyTorch canonical implementation
+c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
+    if (v->node()->kind() != torch::jit::prim::Constant || v->type()->cast<c10::FunctionType>()) {
+        return c10::nullopt;
+    }
+    const torch::jit::Node* node = v->node();
+    const c10::TypePtr& type = v->type();
+    if (type->isSubtypeOf(c10::TensorType::get())) {
+        return node->t(c10::attr::value);
+    } else if (type->isSubtypeOf(c10::BoolType::get())) {
+        return (bool)node->i(c10::attr::value);
+    } else if (
+        type->isSubtypeOf(c10::NumberType::get()) &&
+        node->kindOf(c10::attr::value) == torch::jit::AttributeKind::i) {
+        return node->i(c10::attr::value);
+    } else if (
+        type->isSubtypeOf(c10::NumberType::get()) &&
+        node->kindOf(c10::attr::value) == torch::jit::AttributeKind::f) {
+        return node->f(c10::attr::value);
+    } else if (type->isSubtypeOf(c10::ListType::ofInts())) {
+        try {
+            const auto& is = node->is(c10::attr::value);
+            return is;
+        } catch (const std::exception& ex) {
+            const auto& ival = node->ival(c10::attr::value);
+            return ival;
+        }
+    } else if (type->isSubtypeOf(c10::ListType::ofFloats())) {
+        try {
+            const auto& fs = node->fs(c10::attr::value);
+            return fs;
+        } catch (const std::exception& ex) {
+            const auto& ival = node->ival(c10::attr::value);
+            return ival;
+        }
+    } else if (type->isSubtypeOf(c10::ListType::ofBools())) {
+        const auto bs = c10::fmap<bool>(node->is(c10::attr::value));
+        return bs;
+    } else if (type->isSubtypeOf(c10::ListType::ofTensors())) {
+        try {
+            const auto& ts = node->ts(c10::attr::value);
+            return ts;
+        } catch (const std::exception& ex) {
+            const auto& ival = node->ival(c10::attr::value);
+            return ival;
+        }
+    } else if (type->isSubtypeOf(c10::ListType::ofStrings())) {
+        try {
+            const auto& ss = node->ss(c10::attr::value);
+            auto vals = c10::impl::GenericList(c10::StringType::get());
+            for (const auto& str : ss) {
+                vals.push_back(str);
+            }
+            return vals;
+        } catch (const std::exception& ex) {
+            const auto& ival = node->ival(c10::attr::value);
+            return ival;
+        }
+    } else if (
+        type->cast<c10::ListType>() &&
+        node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
+        const auto& list = node->ival(c10::attr::value);
+        TRTORCH_ASSERT(list.isList(), "Is not a list");
+        return list;
+    } else if (
+        type->cast<c10::DictType>() &&
+        node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
+        const auto& dict = node->ival(c10::attr::value);
+        TRTORCH_ASSERT(dict.isGenericDict(), "Is not a dict");
+        return dict;
+    } else if (
+        type->cast<c10::TupleType>() &&
+        node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
+        const auto& tup = node->ival(c10::attr::value);
+        TRTORCH_ASSERT(tup.isTuple(), "Is not a tuple");
+        return tup;
+    } else if (type == c10::StringType::get()) {
+        const auto& s = node->s(c10::attr::value);
+        return s;
+    } else if (type == c10::DeviceObjType::get()) {
+        auto d = c10::Device(node->s(c10::attr::value));
+        return d;
+    } else if (node->mustBeNone()) {
+        return torch::jit::IValue();
+    } else {
+        std::stringstream ss;
+        ss << "constant literal not supported for: " << type->str();
+        throw std::runtime_error(ss.str());
+    }
+}
+
+} // namespace evaluators
+} // namespace conversion
+} // namespace core
+} // namespace trtorch
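
For context, a minimal usage sketch of the new helper (not part of the commit; the IR string, function name, and the irparser.h include are assumptions): parse a small graph whose output comes from an int[] prim::Constant and resolve the constant through evaluators::toIValue.

#include <memory>

#include "core/conversion/evaluators/eval_util.h"
#include "torch/csrc/jit/ir/irparser.h"

void eval_int_list_constant_example() {
    auto g = std::make_shared<torch::jit::Graph>();
    torch::jit::parseIR(R"IR(
        graph():
            %pad : int[] = prim::Constant[value=[0, 0]]()
            return (%pad))IR",
        g.get());

    // The graph output is the Value produced by the int[] prim::Constant.
    auto ival = trtorch::core::conversion::evaluators::toIValue(g->outputs()[0]);
    if (ival && ival->isIntList()) {
        // ival->toIntList() holds {0, 0} whether the parser stored the value as an
        // integer-array attribute (is) or as a generic IValue attribute (ival).
    }
}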

core/conversion/evaluators/eval_util.h

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+#pragma once
+
+#include "torch/csrc/jit/ir/ir.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace evaluators {
+
+c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);
+
+} // namespace evaluators
+} // namespace conversion
+} // namespace core
+} // namespace trtorch

core/conversion/evaluators/prim.cpp

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,7 @@
 #include <limits>
 
 #include "torch/csrc/jit/ir/ir.h"
-#include "torch/csrc/jit/ir/constants.h"
+//#include "torch/csrc/jit/ir/constants.h"
 #include "ATen/core/functional.h"
 #include "ATen/core/ivalue.h"
 #include "ATen/core/List.h"
@@ -11,6 +11,7 @@
 
 #include "core/conversion/evaluators/evaluators.h"
 #include "core/conversion/evaluators/eval_macros.h"
+#include "core/conversion/evaluators/eval_util.h"
 
 namespace trtorch {
 namespace core {
@@ -25,7 +26,7 @@ auto prim_registrations = RegisterNodeEvaluators()
             if (n->output()->type()->kind() == at::FunctionType::Kind) {
                 return {};
             }
-            return torch::jit::toIValue(n->output());
+            return evaluators::toIValue(n->output());
         }
     }).evaluator({
         torch::jit::prim::NumToTensor,
