Commit d66c617

Author: Elias Ellison

Add support for legacy constructors in JIT

ghstack-source-id: 7cb4972
Pull Request resolved: #74785

1 parent: aa11ac3 · commit: d66c617
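A minimal usage sketch of what this change enables, mirroring the new tests in test/jit/test_misc.py (output comments are illustrative):

import torch

# With this commit, TorchScript can compile the legacy typed constructors
# (torch.LongTensor, torch.FloatTensor, ...) for the two supported schemas.

@torch.jit.script
def from_data():
    # "new(PyObject* data, *, Device? device=None)" overload
    return torch.LongTensor([3])

@torch.jit.script
def from_size():
    # "new(IntArrayRef size, *, Device? device=None)" overload
    return torch.LongTensor(2, 3, 4)

print(from_data())        # tensor([3])
print(from_size().shape)  # torch.Size([2, 3, 4])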

File tree

5 files changed: +157 −1 lines changed


aten/src/ATen/core/interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ namespace c10 {
   _(prim, Placeholder) /* debug */ \
   _(prim, Print) \
   _(prim, EmptyListLiteral) \
+  _(prim, LegacyTypedConstructor) \
   _(prim, PythonOp) \
   _(prim, IgnoredPythonOp) \
   _(prim, Reverse) \

test/jit/test_misc.py

Lines changed: 47 additions & 0 deletions
@@ -258,6 +258,53 @@ def test_return():
 
         FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)
 
+    def test_legacy_tensor_constructor(self):
+        # testing PyObject overload
+        def test_all_dtypes():
+            return (
+                torch.BoolTensor([2]),
+                torch.LongTensor([3]),
+                torch.ByteTensor([4]),
+                torch.CharTensor([5]),
+                torch.DoubleTensor([6]),
+                torch.FloatTensor([7]),
+                torch.IntTensor([8]),
+                torch.ShortTensor([1]),
+                torch.HalfTensor([1]),
+            )
+
+        self.checkScript(test_all_dtypes, ())
+
+        # now test empty overload
+        def empty_overload():
+            return torch.LongTensor(2, 3, 4)
+
+        eager = empty_overload()
+        jit = torch.jit.script(empty_overload)()
+        eager[:] = 1
+        jit[:] = 1
+        self.assertEqual(eager, jit)
+
+        def no_inputs():
+            return torch.DoubleTensor()
+
+        self.checkScript(no_inputs, ())
+
+        # bad schema
+        def multiple_args():
+            return torch.LongTensor(1, [2])
+
+        with self.assertRaisesRegex(RuntimeError, "multiple positional arguments that were not all integers"):
+            torch.jit.script(multiple_args)
+
+        # kwarg bad schema
+        def bad_kwarg():
+            return torch.LongTensor(hello="1")
+
+        with self.assertRaisesRegex(RuntimeError, "hello"):
+            torch.jit.script(bad_kwarg)
+
+
     def test_broadcasting_list(self):
         """
         Test BroadcastingList and torch.nn._size_N_t alias

torch/csrc/jit/frontend/ir_emitter.cpp

Lines changed: 68 additions & 1 deletion
@@ -37,6 +37,7 @@
 
 #include <ATen/core/interned_strings.h>
 #include <ATen/core/jit_type.h>
+#include <torch/csrc/jit/frontend/error_report.h>
 #include <atomic>
 #include <climits>
 #include <set>
@@ -3320,7 +3321,7 @@ struct to_ir {
     auto sv = emitSugaredExpr(apply.callee(), 1);
     auto loc = apply.callee().range();
     if (auto special_form = dynamic_cast<SpecialFormValue*>(sv.get())) {
-      return emitApplySpecialForm(special_form->form(), apply, type_hint);
+      return emitApplySpecialForm(special_form->form(), apply, sv, type_hint);
     }
     auto args = getNamedValues(apply.inputs(), true);
     auto kwargs = emitAttributes(apply.attributes());
@@ -3335,6 +3336,7 @@
   std::shared_ptr<SugaredValue> emitApplySpecialForm(
       Symbol form,
       Apply& apply,
+      std::shared_ptr<SugaredValue> sv,
      const TypePtr& type_hint = nullptr) {
     switch (form) {
       case prim::fork: {
@@ -3439,6 +3441,71 @@
         return std::make_shared<SimpleValue>(
             graph->insertNode(graph->createTuple(inp_values))->output());
       }
+      case prim::LegacyTypedConstructor: {
+        // see legacy_tensor_generic_ctor_new
+        // These legacy constructors do not follow schemas that can be
+        // typed in native_functions.yaml / JIT type signature and are handled
+        // here. Only the two common cases are handled initially:
+        // "new(IntArrayRef size, *, Device? device=None)",
+        // "new(PyObject* data, *, Device? device=None)",
+        // Note: device argument is unused in the kernel
+        auto args = getValues(apply.inputs(), true);
+        auto kwargs = emitAttributes(apply.attributes());
+        auto get_base_error_msg = [&]() {
+          std::stringstream base_error_msg;
+          base_error_msg
+              << "Legacy Tensor Constructor only supports two schemas in TorchScript: \n";
+          base_error_msg
+              << "'new(IntArrayRef size, *, Device? device=None)',\n";
+          base_error_msg << "'new(PyObject* data, *, Device? device=None)\n'";
+          return base_error_msg;
+        };
+        if (kwargs.size() == 1 && kwargs[0].name() != "device") {
+          throw ErrorReport(apply)
+              << get_base_error_msg().str() << "Got kwarg " << kwargs[0].name();
+        }
+        if (kwargs.size() > 1) {
+          throw ErrorReport(apply)
+              << get_base_error_msg().str() << "Got multiple kwargs\n";
+        }
+        auto dtype = dynamic_cast<LegacyTensorConstructor*>(sv.get())->dtype();
+        auto dtype_ivalue = graph->insertConstant(dtype);
+
+        // supporting "new(IntArrayRef size, *, Device? device=None)", through
+        // empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout?
+        // layout=None, Device? device=None, bool? pin_memory=None,
+        // MemoryFormat? memory_format=None) -> Tensor
+        bool all_ints = std::all_of(args.begin(), args.end(), [](Value* v) {
+          return v->type()->cast<IntType>();
+        });
+        if (args.size() == 0) {
+          // empty inputs == torch.tensor([], dtype=....)
+          auto inp_list =
+              graph->insertNode(graph->createList(IntType::get(), {}))
+                  ->output();
+          return std::make_shared<SimpleValue>(graph->insert(
+              aten::tensor,
+              {inp_list},
+              {NamedValue(apply.range(), "dtype", dtype_ivalue)}));
+        } else if (all_ints) {
+          auto inp_list =
+              graph->insertNode(graph->createList(IntType::get(), args))
+                  ->output();
+          return std::make_shared<SimpleValue>(graph->insert(
+              aten::empty,
+              {inp_list},
+              {NamedValue(apply.range(), "dtype", dtype_ivalue)}));
+        } else if (args.size() == 1) {
+          return std::make_shared<SimpleValue>(graph->insert(
+              aten::tensor,
+              {args[0]},
+              {NamedValue(apply.range(), "dtype", dtype_ivalue)}));
+        } else {
+          throw ErrorReport(apply)
+              << get_base_error_msg().str()
+              << "Got multiple positional arguments that were not all integers";
+        }
+      }
       case prim::isinstance: {
         checkApplyNumInputs(apply, 2);
         auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
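To see the two lowering paths described in the comments above, one can script each overload and inspect the emitted graph. A small sketch (assuming this commit is applied), using torch.testing.FileCheck as the updated tests do:

import torch
from torch.testing import FileCheck

def from_size():
    # all-integer positional args take the size path (lowered via aten::empty)
    return torch.LongTensor(2, 3, 4)

def from_data():
    # a single non-integer arg takes the data path (lowered via aten::tensor)
    return torch.LongTensor([3])

FileCheck().check("aten::empty").run(torch.jit.script(from_size).graph)
FileCheck().check("aten::tensor").run(torch.jit.script(from_data).graph)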

torch/csrc/jit/frontend/sugared_value.h

Lines changed: 20 additions & 0 deletions
@@ -1,4 +1,5 @@
 #pragma once
+#include <c10/util/Optional.h>
 #include <functional>
 #include <memory>
 #include <string>
@@ -618,6 +619,25 @@ struct TORCH_API SpecialFormValue : public SugaredValue {
   Symbol form_;
 };
 
+struct TORCH_API LegacyTensorConstructor : public SpecialFormValue {
+  LegacyTensorConstructor(Symbol form, at::ScalarType dtype, at::Device device)
+      : SpecialFormValue(form), device_(device), dtype_(dtype) {}
+
+  static std::shared_ptr<LegacyTensorConstructor> create(
+      Symbol form,
+      at::ScalarType dtype,
+      at::Device device) {
+    return std::make_shared<LegacyTensorConstructor>(form, dtype, device);
+  }
+  at::ScalarType dtype() const {
+    return dtype_;
+  }
+
+ private:
+  at::Device device_;
+  at::ScalarType dtype_;
+};
+
 // matched against for special handling of range expressions
 struct TORCH_API RangeValue : SugaredValue {
   RangeValue(

torch/csrc/jit/python/python_sugared_value.cpp

Lines changed: 21 additions & 0 deletions
@@ -1,5 +1,7 @@
 #include <torch/csrc/jit/python/python_sugared_value.h>
 
+#include <ATen/core/interned_strings.h>
+#include <c10/core/ScalarType.h>
 #include <pybind11/pytypes.h>
 #include <torch/csrc/Dtype.h>
 #include <torch/csrc/Layout.h>
@@ -1160,6 +1162,25 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     throw ErrorReport(loc) << "Cannot call a ScriptModule that is not"
                            << " a submodule of the caller";
   }
+  std::vector<std::pair<const char*, at::ScalarType>> tensor_names = {
+      {"BoolTensor", at::ScalarType::Bool},
+      {"LongTensor", at::ScalarType::Long},
+      {"ByteTensor", at::ScalarType::Byte},
+      {"CharTensor", at::ScalarType::Char},
+      {"DoubleTensor", at::ScalarType::Double},
+      {"FloatTensor", at::ScalarType::Float},
+      {"IntTensor", at::ScalarType::Int},
+      {"ShortTensor", at::ScalarType::Short},
+      {"HalfTensor", at::ScalarType::Half},
+  };
+  for (const auto& name : tensor_names) {
+    if (obj.ptr() == py::module::import("torch").attr(name.first).ptr()) {
+      // torch.LongTensor and other related functions create on cpu,
+      // TODO: add support for torch.cuda.LongTensor for gpu
+      return LegacyTensorConstructor::create(
+          prim::LegacyTypedConstructor, name.second, at::kCPU);
+    }
+  }
 
   py::object builtin_name =
       py::module::import("torch.jit._builtins").attr("_find_builtin")(obj);
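Each recognized constructor name resolves to a fixed dtype and, for now, the CPU device (torch.cuda.* variants are left as a TODO above). A brief sketch of the resulting behavior, assuming this commit is applied:

import torch

@torch.jit.script
def make_half():
    # HalfTensor maps to at::ScalarType::Half in the tensor_names table
    return torch.HalfTensor([1])

out = make_half()
print(out.dtype)   # torch.float16
print(out.device)  # cpu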
