Skip to content

Commit 478803a

Browse files
zdevito authored and facebook-github-bot committed
Introduce type variables to implement generic list operators (#12040)
Summary: We generate specialized list operations for int, float, and Tensor lists so that small lists of integers like the arguments to conv do not involve tons of boxing code. This PR adds a fallback GenericList for List types that contain any other type. It does so by adding type variables to `jit::Type`, and machinery for matching/replacing the type variables during `tryMatchSchema` and operator lookup. It also modifies the builtin list ops to include a fallback that works on a GenericList object that simply holds IValues. This is distinguished from IValue's tuple type so that conversion to/from Python still happens losslessly. Pull Request resolved: #12040 Differential Revision: D10037098 Pulled By: zdevito fbshipit-source-id: 0c5f2864d12e7d33554bf34cc29e5fb700dde150
1 parent 75b1ae1 commit 478803a

16 files changed

+328
-110
lines changed

aten/src/ATen/core/ivalue.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
#include <ATen/core/ivalue.h>
22
#include <ATen/core/Formatting.h>
33

4-
#define TORCH_FORALL_TAGS(_) \
5-
_(None) \
6-
_(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) \
7-
_(TensorList) _(Blob)
4+
#define TORCH_FORALL_TAGS(_) \
5+
_(None) \
6+
_(Tensor) \
7+
_(Double) \
8+
_(Int) \
9+
_(Tuple) \
10+
_(IntList) \
11+
_(DoubleList) \
12+
_(String) \
13+
_(TensorList) \
14+
_(Blob) \
15+
_(GenericList)
816

917
namespace torch { namespace jit {
1018

aten/src/ATen/core/ivalue.h

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ struct CAFFE2_API ConstantString final : c10::intrusive_ptr_target {
3535

3636
// non-mutable list
3737
template <typename Elem>
38-
struct C10_EXPORT ConstantList final : c10::intrusive_ptr_target {
38+
struct C10_EXPORT ConstantList : c10::intrusive_ptr_target {
3939
private:
4040
const std::vector<Elem> elements_;
4141
public:
42+
typedef Elem ElemType;
4243
ConstantList(std::vector<Elem> elements_)
4344
: elements_(std::move(elements_)) {}
4445
static c10::intrusive_ptr<ConstantList<Elem>> create(std::vector<Elem> elements_) {
@@ -53,10 +54,16 @@ struct C10_EXPORT ConstantList final : c10::intrusive_ptr_target {
5354
};
5455

5556
struct IValue;
56-
using Tuple = ConstantList<IValue>;
57+
struct C10_EXPORT Tuple : public ConstantList<IValue> {
58+
using ConstantList<IValue>::ConstantList;
59+
static c10::intrusive_ptr<Tuple> create(std::vector<IValue> elements_) {
60+
return c10::make_intrusive<Tuple>(std::move(elements_));
61+
}
62+
};
5763
using IntList = ConstantList<int64_t>;
5864
using TensorList = ConstantList<at::Tensor>;
5965
using DoubleList = ConstantList<double>;
66+
using GenericList = ConstantList<IValue>;
6067

6168
// IValue is the generic tagged union used by the interpreter to hold
6269
// all value types.
@@ -65,10 +72,18 @@ using DoubleList = ConstantList<double>;
6572
// to mark whether that type is a subtype of c10::intrusive_ptr_target and needs
6673
// retain/release calls.
6774

68-
#define TORCH_FORALL_TAGS(_) \
69-
_(None) \
70-
_(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) \
71-
_(TensorList) _(Blob)
75+
#define TORCH_FORALL_TAGS(_) \
76+
_(None) \
77+
_(Tensor) \
78+
_(Double) \
79+
_(Int) \
80+
_(Tuple) \
81+
_(IntList) \
82+
_(DoubleList) \
83+
_(String) \
84+
_(TensorList) \
85+
_(Blob) \
86+
_(GenericList)
7287

7388
struct CAFFE2_API IValue final {
7489
IValue()
@@ -207,6 +222,7 @@ struct CAFFE2_API IValue final {
207222
const std::vector<int64_t>& toIntListRef() const;
208223
const std::vector<double>& toDoubleListRef() const;
209224
const std::vector<at::Tensor>& toTensorListRef() const;
225+
const std::vector<IValue>& toGenericListRef() const;
210226

211227
// ConstantString
212228
IValue(c10::intrusive_ptr<ConstantString> v);
@@ -247,6 +263,19 @@ struct CAFFE2_API IValue final {
247263
return toIntrusivePtr<TensorList>();
248264
}
249265

266+
//GenericList
267+
IValue(c10::intrusive_ptr<GenericList> v);
268+
IValue(std::vector<IValue> v);
269+
bool isGenericList() const { return Tag::GenericList == tag; }
270+
c10::intrusive_ptr<GenericList> toGenericList() && {
271+
AT_ASSERT(isGenericList());
272+
return moveToIntrusivePtr<GenericList>();
273+
}
274+
c10::intrusive_ptr<GenericList> toGenericList() const & {
275+
AT_ASSERT(isGenericList());
276+
return toIntrusivePtr<GenericList>();
277+
}
278+
250279
// None
251280
bool isNone() {
252281
return Tag::None == tag;
@@ -362,12 +391,14 @@ DEFINE_TO(int64_t, toInt)
362391
DEFINE_TO(c10::intrusive_ptr<DoubleList>, toDoubleList)
363392
DEFINE_TO(c10::intrusive_ptr<IntList>, toIntList)
364393
DEFINE_TO(c10::intrusive_ptr<TensorList>, toTensorList)
394+
DEFINE_TO(c10::intrusive_ptr<GenericList>, toGenericList)
365395
DEFINE_TO(c10::intrusive_ptr<ConstantString>, toString)
366396
DEFINE_TO(at::Scalar, toScalar)
367397
DEFINE_TO(bool, toInt)
368398
DEFINE_TO(std::vector<int64_t>, toIntListRef)
369399
DEFINE_TO(std::vector<double>, toDoubleListRef)
370400
DEFINE_TO(std::vector<at::Tensor>, toTensorListRef)
401+
DEFINE_TO(std::vector<IValue>, toGenericListRef)
371402

372403
#undef DEFINE_TO
373404

@@ -433,6 +464,14 @@ inline IValue::IValue(c10::intrusive_ptr<TensorList> v)
433464
inline IValue::IValue(std::vector<at::Tensor> v)
434465
: IValue(TensorList::create(std::move(v))) {}
435466

467+
inline IValue::IValue(c10::intrusive_ptr<GenericList> v)
468+
: tag(Tag::GenericList), is_intrusive_ptr(true) {
469+
payload.as_intrusive_ptr = v.release();
470+
}
471+
inline IValue::IValue(std::vector<IValue> v)
472+
: IValue(GenericList::create(std::move(v))) {}
473+
474+
436475
inline const std::vector<int64_t>& IValue::toIntListRef() const {
437476
return toIntList()->elements();
438477
}
@@ -445,5 +484,9 @@ inline const std::vector<at::Tensor>& IValue::toTensorListRef() const {
445484
return toTensorList()->elements();
446485
}
447486

487+
inline const std::vector<IValue>& IValue::toGenericListRef() const {
488+
return toGenericList()->elements();
489+
}
490+
448491

449492
}}

test/test_jit.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2645,18 +2645,23 @@ def stuff3(x):
26452645
return torch.ones(x), x
26462646
self.checkScript(stuff3, ([3, 2],))
26472647

2648-
def test_nested_list_error(self):
2649-
with self.assertRaisesRegex(RuntimeError, "Lists can only contain"):
2650-
@torch.jit.script
2651-
def foo(x):
2652-
# type: (Tuple[List[List[int]]]) -> int
2653-
return 4
2648+
def test_nested_list(self):
2649+
def foo(z):
2650+
# type: (Tuple[int, List[List[int]]]) -> int
2651+
x, y = z
2652+
return y[0][1]
2653+
self.checkScript(foo, ((1, [[1, 2], [3, 4]]),))
2654+
2655+
def test_nested_list_construct(self):
2656+
def foo():
2657+
return [[4]] + [[4, 5]]
2658+
self.checkScript(foo, ())
26542659

2655-
def test_nested_list_construct_error(self):
2656-
with self.assertRaisesRegex(RuntimeError, "Lists can only contain"):
2660+
def test_generic_list_errors(self):
2661+
with self.assertRaisesRegex(RuntimeError, "previously matched to type"):
26572662
@torch.jit.script
26582663
def foo(x):
2659-
return [[4]]
2664+
return [[x]] + [[1]]
26602665

26612666
def test_script_cu(self):
26622667
cu = torch.jit.CompilationUnit('''

torch/csrc/jit/argument_spec.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ static_assert(sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type),
6161
struct ArgumentSpec {
6262
ArgumentSpec(bool with_grad, at::ArrayRef<IValue> inputs, size_t num_flat_inputs) {
6363
hash_code = num_flat_inputs;
64-
6564
args.resize(num_flat_inputs);
6665
size_t offset = 0;
6766
for (size_t i = 0; i < inputs.size(); ++i) {

torch/csrc/jit/export.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,8 @@ void ModuleEncoder::EncodeTypeInfo(
564564
type_proto->set_denotation("GeneratorType");
565565
} else if (kind == TypeKind::StringType) {
566566
type_proto->set_denotation("StringType");
567+
} else if (kind == TypeKind::VarType) {
568+
type_proto->set_denotation("TypeVar:" + type->expect<VarType>()->name());
567569
} else {
568570
throw std::runtime_error("unexpected type kind");
569571
}

torch/csrc/jit/import.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,10 @@ TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) {
260260
return NoneType::get();
261261
} else if (kind == "GeneratorType") {
262262
return GeneratorType::get();
263-
}else if (kind == "StringType") {
263+
} else if (kind == "StringType") {
264264
return StringType::get();
265+
} else if (kind.find("TypeVar:") == 0) {
266+
return VarType::create(kind.substr(strlen("TypeVar:")));
265267
} else {
266268
throw std::runtime_error("unexpected string for type kind");
267269
}

torch/csrc/jit/operator.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,14 @@ struct SchemaParser {
6262
auto tok = L.expect(TK_IDENT);
6363
auto text = tok.text();
6464
auto it = type_map.find(text);
65-
if(it == type_map.end())
65+
if(it == type_map.end()) {
66+
if(text.size() > 0 && islower(text[0])) {
67+
// lower case identifiers that are not otherwise valid types
68+
// are treated as type variables
69+
return VarType::create(text);
70+
}
6671
throw ErrorReport(tok.range) << "unknown type specifier";
72+
}
6773
return it->second;
6874
}
6975
void parseArgumentType(std::vector<Argument>& arguments) {
@@ -358,9 +364,16 @@ bool Operator::matches(const Node* node) const {
358364
if(actuals.size() < formals.size())
359365
return false;
360366

367+
368+
TypeEnv type_env;
361369
for(size_t i = 0; i < formals.size(); ++i) {
362-
// mismatched input type
363-
if (!actuals[i]->type()->isSubtypeOf(formals[i].type)) {
370+
try {
371+
TypePtr formal = matchTypeVariables(formals[i].type, actuals[i]->type(), type_env);
372+
// mismatched input type
373+
if (!actuals[i]->type()->isSubtypeOf(formal)) {
374+
return false;
375+
}
376+
} catch(TypeMatchError& err) {
364377
return false;
365378
}
366379
}

torch/csrc/jit/pybind_utils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ inline IValue toIValue(py::handle obj, const TypePtr& type) {
142142
}
143143
case TypeKind::NumberType:
144144
case TypeKind::GeneratorType:
145+
case TypeKind::VarType:
145146
break;
146147
}
147148
AT_ERROR("Missing cases in toIValue for type: ", type->str(), "! File a bug report.");
@@ -199,6 +200,14 @@ inline py::object toPyObject(IValue&& ivalue) {
199200
return py::cast(ivalue.toDoubleListRef());
200201
} else if (ivalue.isTensorList()) {
201202
return py::cast(ivalue.toTensorListRef());
203+
} else if (ivalue.isGenericList()) {
204+
auto list = ivalue.toGenericList();
205+
const auto & elements = list->elements();
206+
py::list t { elements.size() };
207+
for (size_t i = 0; i < elements.size(); ++i) {
208+
t[i] = toPyObject(IValue{elements[i]});
209+
}
210+
return t;
202211
} else if (ivalue.isTuple()) {
203212
auto tuple = ivalue.toTuple();
204213
const auto & elements = tuple->elements();

torch/csrc/jit/python_ir.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,8 @@ void initPythonIRBindings(PyObject * module_) {
455455
return "StringType";
456456
case TypeKind::GeneratorType:
457457
return "GeneratorType";
458+
case TypeKind::VarType:
459+
return "VarType";
458460
}
459461
// not reachable, but some compilers complain
460462
AT_ERROR("Unknown Type Kind");

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -399,9 +399,17 @@ RegisterOperators reg({
399399
return 0;
400400
};
401401
} else {
402-
std::stringstream ss;
403-
ss << "unsupported list type: " << *lt->getElementType();
404-
throw std::runtime_error(ss.str());
402+
return [=](Stack& stack) {
403+
const size_t stack_size = stack.size();
404+
std::vector<IValue> vals;
405+
vals.reserve(num_inputs);
406+
for (size_t i = stack_size - num_inputs; i < stack_size; ++i) {
407+
vals.push_back(std::move(stack[i]));
408+
}
409+
drop(stack, num_inputs);
410+
push(stack, std::move(vals));
411+
return 0;
412+
};
405413
}
406414
}),
407415
});
@@ -506,11 +514,7 @@ Operation listEq(Node* node) {
506514
T a;
507515
T b;
508516
pop(stack, a, b);
509-
if (a->elements() == b->elements()) {
510-
push(stack, 1);
511-
} else {
512-
push(stack, 0);
513-
}
517+
push(stack, a->elements() == b->elements() ? 1 : 0);
514518
return 0;
515519
};
516520
}
@@ -604,31 +608,25 @@ Operation listSlice(Node* node) {
604608
}
605609

606610
RegisterOperators reg2({
607-
Operator("aten::select(int[] a, int b) -> int", listSelect<Shared<IntList>>),
608-
Operator("aten::select(float[] a, int b) -> float", listSelect<Shared<DoubleList>>),
609-
Operator("aten::select(Tensor[] a, int b) -> Tensor", listSelect<Shared<TensorList>>),
610611

611-
Operator("aten::len(int[] a) -> int", listLen<Shared<IntList>>),
612-
Operator("aten::len(float[] a) -> int", listLen<Shared<DoubleList>>),
613-
Operator("aten::len(Tensor[] a) -> int", listLen<Shared<TensorList>>),
612+
#define CREATE_LIST_OPS(decl_type, c_type) \
613+
Operator("aten::select(" decl_type "[] a, int b) -> " decl_type, listSelect<Shared<c_type>>), \
614+
Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
615+
Operator("aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type "[]", listAdd<Shared<c_type>, c_type::ElemType>), \
616+
Operator( \
617+
"aten::slice(" decl_type "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type "[]", \
618+
listSlice<Shared<c_type>, c_type::ElemType>),
619+
620+
621+
CREATE_LIST_OPS("int", IntList)
622+
CREATE_LIST_OPS("float", DoubleList)
623+
CREATE_LIST_OPS("Tensor", TensorList)
624+
CREATE_LIST_OPS("t", GenericList)
614625

615626
Operator("aten::eq(int[] a, int[] b) -> int", listEq<Shared<IntList>>),
616627
Operator("aten::eq(float[] a, float[] b) -> int", listEq<Shared<DoubleList>>),
617628
Operator("aten::eq(Tensor[] a, Tensor[] b) -> int", listEq<Shared<TensorList>>),
618629

619-
Operator("aten::add(int[] a, int[] b) -> int[]", listAdd<Shared<IntList>, int64_t>),
620-
Operator("aten::add(float[] a, float[] b) -> float[]", listAdd<Shared<DoubleList>, double>),
621-
Operator("aten::add(Tensor[] a, Tensor[] b) -> Tensor[]", listAdd<Shared<TensorList>, at::Tensor>),
622-
623-
Operator(
624-
"aten::slice(int[] l, int start, int end=9223372036854775807, int step=1) -> int[]",
625-
listSlice<Shared<IntList>, int64_t>),
626-
Operator(
627-
"aten::slice(float[] l, int start, int end=9223372036854775807, int step=1) -> float[]",
628-
listSlice<Shared<DoubleList>, double>),
629-
Operator(
630-
"aten::slice(Tensor[] l, int start, int end=9223372036854775807, int step=1) -> Tensor[]",
631-
listSlice<Shared<TensorList>, at::Tensor>),
632630

633631
DEFINE_BINARY_OP(aten::add, a + b)
634632
DEFINE_BINARY_OP(aten::sub, a - b)

0 commit comments

Comments (0)