From 18a6541b82ff92ddeae29aeade87e465068696da Mon Sep 17 00:00:00 2001 From: Norman Mu Date: Thu, 26 Jul 2018 14:05:27 -0700 Subject: [PATCH 01/17] Create IDEEP fallback operators for ctc decoder ops (#9847) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9847 CTCBeamSearchDecoder and CTCGreedyDecoder do not currently support IDEEP execution. Add fallback operators to allow IDEEP execution of models that use these operators. Reviewed By: yinghai Differential Revision: D9006234 fbshipit-source-id: fc539ba67b07d1f960d28564d8adde0be8690649 --- caffe2/ideep/operators/operator_fallback_ideep.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/caffe2/ideep/operators/operator_fallback_ideep.cc b/caffe2/ideep/operators/operator_fallback_ideep.cc index 0d8b6fd55b205b..16df4962b4284c 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.cc +++ b/caffe2/ideep/operators/operator_fallback_ideep.cc @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include #include #include @@ -112,4 +114,12 @@ REGISTER_IDEEP_OPERATOR( PRelu, IDEEPFallbackOp>); +// ctc decoder operators +REGISTER_IDEEP_OPERATOR( + CTCGreedyDecoder, + IDEEPFallbackOp>); +REGISTER_IDEEP_OPERATOR( + CTCBeamSearchDecoder, + IDEEPFallbackOp>); + } // namespace caffe2 From d1260d26fe4ef11c1831a4eab34508eff4ad95ff Mon Sep 17 00:00:00 2001 From: Fei Sun Date: Thu, 26 Jul 2018 14:38:27 -0700 Subject: [PATCH 02/17] Sleep before run (#9891) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9891 Add an argument to benchmark binary to specify the seconds to sleep before the run and after the warmup. Pull Request resolved: https://github.com/pytorch/pytorch/pull/9880 Reviewed By: llyfacebook Differential Revision: D9014254 Pulled By: sf-wind fbshipit-source-id: d5566186c8ed768f1e170e9266c5f2d6077391e0 --- binaries/benchmark_helper.cc | 6 +++++- binaries/benchmark_helper.h | 1 + binaries/caffe2_benchmark.cc | 7 ++++++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/binaries/benchmark_helper.cc b/binaries/benchmark_helper.cc index 52b51174cf34d1..27a593aaa81963 100644 --- a/binaries/benchmark_helper.cc +++ b/binaries/benchmark_helper.cc @@ -215,7 +215,8 @@ void runNetwork( const bool wipe_cache, const bool run_individual, const int warmup, - const int iter) { + const int iter, + const int sleep_before_run) { if (!net_def.has_name()) { net_def.set_name("benchmark"); } @@ -234,6 +235,9 @@ void runNetwork( if (wipe_cache) { caffe2::wipe_cache(); } + if (sleep_before_run > 0) { + sleep(sleep_before_run); + } LOG(INFO) << "Main runs."; CAFFE_ENFORCE( iter >= 0, diff --git a/binaries/benchmark_helper.h b/binaries/benchmark_helper.h index 0a52e16a50079c..5af2d91cec4bc7 100644 --- a/binaries/benchmark_helper.h +++ b/binaries/benchmark_helper.h @@ -96,4 +96,5 @@ void runNetwork( const bool, const bool, const int, + const int, const int); diff --git a/binaries/caffe2_benchmark.cc b/binaries/caffe2_benchmark.cc index 729479a17c7598..230210644947cd 100644 --- a/binaries/caffe2_benchmark.cc +++ b/binaries/caffe2_benchmark.cc @@ -62,6 +62,10 @@ CAFFE2_DEFINE_bool( run_individual, false, "Whether to benchmark individual operators."); +CAFFE2_DEFINE_int( + sleep_before_run, + 0, + "The seconds to sleep before starting the benchmarking."); CAFFE2_DEFINE_bool( text_output, false, @@ -115,7 +119,8 @@ int main(int argc, char** argv) { caffe2::FLAGS_wipe_cache, caffe2::FLAGS_run_individual, caffe2::FLAGS_warmup, - caffe2::FLAGS_iter); + caffe2::FLAGS_iter, + caffe2::FLAGS_sleep_before_run); writeOutput( workspace, From d65c667f2895ed71081d0b16fbfdf27054d5fd4d Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Thu, 26 Jul 2018 15:53:47 -0700 Subject: [PATCH 03/17] Avoid divide-by-zero when hamming_window window length is 0. Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9896 Reviewed By: ezyang Differential Revision: D9018572 Pulled By: gchanan fbshipit-source-id: fa314687973124165bffb3084932d8ab6d872a93 --- aten/src/ATen/native/TensorFactories.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 1ddac71cf299b0..d6ebbd4573a70c 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -581,6 +581,9 @@ Tensor hamming_window( double beta, const TensorOptions& options) { window_function_checks("hamming_window", options, window_length); + if (window_length == 0) { + return native::empty({0}, options); + } if (window_length == 1) { return native::ones({1}, options); } From 9df9c46992d596cdbc74ef94ead6d628c8e54c08 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Thu, 26 Jul 2018 17:03:11 -0700 Subject: [PATCH 04/17] fix loading 1dim tensor from 0.3.* to 0dim tensor (#9781) Summary: This PR fixes #9743 . Adding backward support when loading a checkpoint from 0.3.* with 1dim tensor, they are now 0 dim tensor in 0.4+. Pull Request resolved: https://github.com/pytorch/pytorch/pull/9781 Differential Revision: D8988196 Pulled By: ailzhang fbshipit-source-id: a7a1bc771d597394208430575d5a4d23b9653fef --- torch/nn/modules/module.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index a00ff3dd9c268c..61c93a7e810f98 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -642,6 +642,10 @@ def _load_from_state_dict(self, state_dict, prefix, metadata, strict, missing_ke if key in state_dict: input_param = state_dict[key] + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + if input_param.shape != param.shape: # local shape should match the one in checkpoint error_msgs.append('size mismatch for {}: copying a param of {} from checkpoint, ' From b7b61a8eb4402129710df7f18a16b1931eff9c84 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 26 Jul 2018 17:55:53 -0700 Subject: [PATCH 05/17] Change expect, cast on Type to return shared pointers, make isSubtypeOf accept TypePtr (#9786) Summary: Follow up task of #9584. Commit 1: - change expect/cast to return shared pointers instead of raw pointer - isSubtypeOf accept TypePtr instead. Use `x->isSubtypeOf(NumberType::get())` rather than `x->isSubtypeOf(*NumberType::get())` Commit 2: - to address enable_shared_from_this pitfalls, we make the constructor private and expose the factory method to make sure user can only create it using our factory method. Pull Request resolved: https://github.com/pytorch/pytorch/pull/9786 Reviewed By: zdevito Differential Revision: D8980441 Pulled By: wanchaol fbshipit-source-id: e5c923fc57a701014310e77cf29985b43bb25364 --- torch/csrc/jit/argument_spec.h | 2 +- torch/csrc/jit/autodiff.cpp | 2 +- torch/csrc/jit/constants.cpp | 8 +- torch/csrc/jit/export.cpp | 2 +- torch/csrc/jit/fusion_compiler.h | 2 +- torch/csrc/jit/ir.cpp | 4 +- torch/csrc/jit/ir.h | 10 +- torch/csrc/jit/operator.cpp | 6 +- torch/csrc/jit/passes/erase_number_types.cpp | 4 +- torch/csrc/jit/passes/graph_fuser.cpp | 8 +- torch/csrc/jit/passes/lower_tuples.cpp | 4 +- torch/csrc/jit/passes/onnx/peephole.cpp | 4 +- torch/csrc/jit/passes/shape_analysis.cpp | 32 +-- torch/csrc/jit/python_ir.cpp | 4 +- torch/csrc/jit/register_prim_ops.cpp | 2 +- torch/csrc/jit/script/compiler.cpp | 24 +-- torch/csrc/jit/script/init.cpp | 4 +- torch/csrc/jit/script/module.h | 4 +- torch/csrc/jit/test_jit.cpp | 4 +- torch/csrc/jit/type.cpp | 12 +- torch/csrc/jit/type.h | 194 ++++++++++++------- 21 files changed, 195 insertions(+), 141 deletions(-) diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h index 69b5036766e998..d6bd90cb708784 100644 --- a/torch/csrc/jit/argument_spec.h +++ b/torch/csrc/jit/argument_spec.h @@ -153,7 +153,7 @@ struct ArgumentInfo { operator TypePtr() const { if(!defined()) return DynamicType::get(); - return std::make_shared(type(), device(), sizes(), strides()); + return TensorType::create(type(), device(), sizes(), strides()); } private: // offsetinto sizes_strides() array where the sizes start for tensor j diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index ceb379a53925d4..f3e52c0171b121 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -35,7 +35,7 @@ bool isDifferentiable(Node * n) { if (!hasOneValuedInput(n, attr::alpha) || !hasOneValuedInput(n, attr::beta)) return false; } - auto isTensor = [](Value* v) { return v->type()->isSubtypeOf(*DynamicType::get()); }; + auto isTensor = [](Value* v) { return v->type()->isSubtypeOf(DynamicType::get()); }; if(!std::all_of(n->inputs().begin(), n->inputs().end(), isTensor) || !std::all_of(n->outputs().begin(), n->outputs().end(), isTensor)) diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp index 1c8bf928aab5dd..3c4ad0c130ea31 100644 --- a/torch/csrc/jit/constants.cpp +++ b/torch/csrc/jit/constants.cpp @@ -38,14 +38,14 @@ RegisterOperators reg({ prim::Constant, [](Node* node) -> Operation { TypePtr type = node->output()->type(); - if(type->isSubtypeOf(*DynamicType::get())) { + if(type->isSubtypeOf(DynamicType::get())) { auto t = autograd::make_variable(node->t(attr::value)); return [t](Stack& stack) { stack.push_back(t); return 0; }; } else if ( - type->isSubtypeOf(*NumberType::get()) && + type->isSubtypeOf(NumberType::get()) && node->kindOf(attr::value) == AttributeKind::i) { auto i = node->i(attr::value); return [i](Stack& stack) { @@ -53,14 +53,14 @@ RegisterOperators reg({ return 0; }; } else if ( - type->isSubtypeOf(*NumberType::get()) && + type->isSubtypeOf(NumberType::get()) && node->kindOf(attr::value) == AttributeKind::f) { auto f = node->f(attr::value); return [f](Stack& stack) { push(stack, f); return 0; }; - } else if(type->isSubtypeOf(*ListType::ofInts())) { + } else if(type->isSubtypeOf(ListType::ofInts())) { auto is = node->is(attr::value); return [is](Stack& stack) { push(stack, is); diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index 90120f1be0fb95..71dec999c40216 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -156,7 +156,7 @@ void addAttribute(onnx::NodeProto * n_p, jit::Node * n, jit::Symbol name, Export void encodeTypeProtoTensorType(onnx::TypeProtoTensor* tensor_type, Value* n) { onnx::TensorShapeProto* shape = tensor_type->mutable_shape(); - if (TensorType* node_type = n->type()->cast()) { + if (TensorTypePtr node_type = n->type()->cast()) { const std::vector& sizes = node_type->sizes(); for (std::int64_t s : sizes) { shape->add_dim(s); diff --git a/torch/csrc/jit/fusion_compiler.h b/torch/csrc/jit/fusion_compiler.h index 969cc1fc05566e..6c4759aefb692a 100644 --- a/torch/csrc/jit/fusion_compiler.h +++ b/torch/csrc/jit/fusion_compiler.h @@ -29,7 +29,7 @@ struct TensorDesc { : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {} TensorDesc(const at::Tensor& t) : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {} - TensorDesc(TensorType *type) + TensorDesc(TensorTypePtr type) : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {} // number of dimensions after contiguity compression diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 3cd1d46b7df2af..6edf2bc176e364 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -240,7 +240,7 @@ static void checkSameDevice(const Node* node) { bool has_device = false; int device; auto checkValue = [&](const Value* v) { - if(TensorType* type = v->type()->cast()) { + if(TensorTypePtr type = v->type()->cast()) { if(!has_device) { has_device = true; device = type->device(); @@ -596,7 +596,7 @@ at::optional Node::get(Symbol name) const { // disambiguate via schema at::Tensor ten = t(name); const Argument* arg = findArgument(schema(), name).second; - if(arg->type->isSubtypeOf(*NumberType::get())) { + if(arg->type->isSubtypeOf(NumberType::get())) { return IValue(at::Scalar(ten)); } return IValue(ten); diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 8c940626699cd7..9af468e6ee06e7 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -181,7 +181,7 @@ struct Value { public: Value* setType(const TypePtr type); void inferTypeFrom(const at::Tensor& output) { - setType(std::make_shared(output)); + setType(TensorType::create(output)); } const TypePtr & type() const { JIT_ASSERT(type_ != nullptr); @@ -995,13 +995,13 @@ friend struct Block; } Node* createTuple(at::ArrayRef values) { auto types = fmap(values, [](Value* v) { return v->type(); }); - auto tt = std::make_shared(std::move(types)); + auto tt = TupleType::create(std::move(types)); auto n = create(prim::TupleConstruct, values); n->output()->setType(tt); return n; } Node* createTupleUnpack(Value * v) { - TupleType* tt = v->type()->expect(); + TupleTypePtr tt = v->type()->expect(); auto n = create(prim::TupleUnpack, {v}, 0); for(auto & element : tt->elements()) { n->addOutput()->setType(element); @@ -1011,9 +1011,9 @@ friend struct Block; Node* createList(const TypePtr& elem_type, at::ArrayRef values) { auto n = create(prim::ListConstruct, values); for(const auto & v : values) { - JIT_ASSERT(v->type()->isSubtypeOf(*elem_type)); + JIT_ASSERT(v->type()->isSubtypeOf(elem_type)); } - n->output()->setType(std::make_shared(elem_type)); + n->output()->setType(ListType::create(elem_type)); return n; } Node* createNumToTensor(Value* value) { diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index 26e314c53eaa3a..560239948325a3 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -65,7 +65,7 @@ struct SchemaParser { void parseType(Argument& arg) { arg.type = parseBaseType(); if(L.nextIf('[')) { - arg.type = std::make_shared(arg.type); + arg.type = ListType::create(arg.type); if(L.cur().kind == TK_NUMBER) { arg.N = std::stoll(L.next().text()); } @@ -328,7 +328,7 @@ at::optional attributeKindOf(TypePtr type) { case TypeKind::FloatType: return AttributeKind::f; case TypeKind::NumberType: return AttributeKind::t; case TypeKind::ListType: - if(type->isSubtypeOf(*ListType::ofInts())) + if(type->isSubtypeOf(ListType::ofInts())) return AttributeKind::is; else return at::nullopt; @@ -338,7 +338,7 @@ at::optional attributeKindOf(TypePtr type) { } bool typeMatches(TypePtr actual, TypePtr formal) { - return actual->isSubtypeOf(*formal); + return actual->isSubtypeOf(formal); } bool Operator::matches(const Node* node) const { diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp index 91f08c0941e7c2..0892b3e6cbdfe3 100644 --- a/torch/csrc/jit/passes/erase_number_types.cpp +++ b/torch/csrc/jit/passes/erase_number_types.cpp @@ -13,7 +13,7 @@ static void EraseNumberTypesOnBlock(Block* block) { case prim::Constant: { // remove primitive constants, replacing with tensor equivalent // ONNX does not support non-tensor constants - if(it->output()->type()->isSubtypeOf(*NumberType::get())) { + if(it->output()->type()->isSubtypeOf(NumberType::get())) { auto s = *constant_as(it->output()); WithInsertPoint guard(*it); Value* r = insertConstant(*block->owningGraph(), s.toTensor()); @@ -27,7 +27,7 @@ static void EraseNumberTypesOnBlock(Block* block) { } break; default: { for(auto o : it->outputs()) { - if (o->type()->isSubtypeOf(*NumberType::get())) { + if (o->type()->isSubtypeOf(NumberType::get())) { o->setType(TensorType::fromNumberType(o->type())); } } diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index 660a4ac4e8ad38..745e910ccc9f9a 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -83,12 +83,12 @@ bool isSimpleMap(Node *node) { return false; // Make sure that the node doesn't broadcast. JIT_ASSERT(node->inputs().size() > 0); - TensorType* expected_type = node->inputs()[0]->type()->cast(); + TensorTypePtr expected_type = node->inputs()[0]->type()->cast(); if (!expected_type) return false; //type checking is intentionally dropped from isSimpleMap //isFusable is checking input/output types as there are some exceptions from allFloatIO requirement - static const auto equal_modulo_strides = [](TensorType* expected, const TypePtr& _actual) { - TensorType* actual = _actual->cast(); + static const auto equal_modulo_strides = [](const TensorTypePtr& expected, const TypePtr& _actual) { + TensorTypePtr actual = _actual->cast(); return actual && expected->device() == actual->device() && expected->sizes() == actual->sizes(); @@ -182,7 +182,7 @@ struct GraphFuser { } bool allOutputsHaveSameSize(Node * node) { - TensorType *tt_ptr = nullptr; + TensorTypePtr tt_ptr = nullptr; for (const auto i : node->inputs()) { auto cur_tt_ptr = i->type()->cast(); if (!cur_tt_ptr) { diff --git a/torch/csrc/jit/passes/lower_tuples.cpp b/torch/csrc/jit/passes/lower_tuples.cpp index 34f8c56f5607fe..89c74d4cf1fa7c 100644 --- a/torch/csrc/jit/passes/lower_tuples.cpp +++ b/torch/csrc/jit/passes/lower_tuples.cpp @@ -43,7 +43,7 @@ static void VisitNode(Node* n, Node* insert_point) { // flatten the input list op(a, tup, b) --> op(a, t0, t1, b) for(size_t i = 0; i < n->inputs().size();) { auto input = n->inputs()[i]; - if(TupleType* tt = input->type()->cast()) { + if(TupleTypePtr tt = input->type()->cast()) { JIT_ASSERTM(white_list.count(n->kind()) > 0, "tuple appears in op that does not forward tuples"); JIT_ASSERTM(input->node()->kind() == prim::TupleConstruct, "tuple use not matched to tuple construct"); for(size_t j = 0; j < tt->elements().size(); ++j) { @@ -68,7 +68,7 @@ static void VisitNode(Node* n, Node* insert_point) { // and: // tup = (t0, t1) // is placed at the current insertion point - if(TupleType* tt = output->type()->cast()) { + if(TupleTypePtr tt = output->type()->cast()) { JIT_ASSERTM(white_list.count(n->kind()) > 0, "tuple appears in op that does not forward tuples"); for(size_t j = 0; j < tt->elements().size(); j++) { n->insertOutput(i + 1 + j)->setType(tt->elements()[j]); diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 4620dd3812d56b..ea256f8e1f867f 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -265,13 +265,13 @@ void pushPackingPastRnn(Block *b) { // unhygenic way, Pytorch ends up propagating an incorrect type. // Until a long-term cleanup comes around, we can fix this by // resetting the size to the correct value. - TensorType* oldType = rnn->inputs()[0]->type()->cast(); + TensorTypePtr oldType = rnn->inputs()[0]->type()->cast(); if (oldType) { std::vector new_sizes; new_sizes.push_back(oldType->sizes()[0]); new_sizes.push_back(oldType->sizes()[1]); new_sizes.push_back(rnn->i(attr::hidden_size)); - TensorTypePtr newType = std::make_shared( + TensorTypePtr newType = TensorType::create( oldType->scalarType(), oldType->device(), new_sizes); next->outputs()[0]->setType(newType); } diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 3b18699f94ffcd..e6136b03c4414e 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -42,12 +42,12 @@ IValue representativeValue(Value* v) { if(auto iv = toIValue(v)) { return *iv; } - if (TensorType* type = type_->cast()) { + if (TensorTypePtr type = type_->cast()) { auto backend = type->device() == -1 ? at::kCPU : at::kCUDA; at::DeviceGuard device_guard(type->device()); auto& attype = at::getType(backend, type->scalarType()); return attype.tensor(type->sizes(), type->strides()).zero_(); - } else if (type_->isSubtypeOf(*FloatType::get())) { + } else if (type_->isSubtypeOf(FloatType::get())) { return 0.f; } // we should not get here because isValidArgumentForRunning should have @@ -63,8 +63,8 @@ void PropagateShapeOnBlock(Block * block, bool insert_expands=true); // for each node in the schema with type Tensor, extract the TensorType // returns at::nullopt if any Tensor in the schema does not have a known shape // ignores non-tensor in the list of inputs -at::optional> gatherTensorTypes(Node *node) { - std::vector tensor_types; +at::optional> gatherTensorTypes(Node *node) { + std::vector tensor_types; auto & schema = node->schema(); auto & args = schema.arguments; @@ -75,12 +75,12 @@ at::optional> gatherTensorTypes(Node *node) { size_t input_i = 0; for (auto& arg : args) { size_t consume_n; // how many tensors do we check for in the input list - if (arg.type->isSubtypeOf(*ListType::ofTensors())) { + if (arg.type->isSubtypeOf(ListType::ofTensors())) { // we have a list of tensor, there is only ever one list // so we calculte how many elements must be in it by how much bigger // or smaller the input list is compared to the arguments in the schema consume_n = node->inputs().size() + 1 - args.size(); - } else if (arg.type->isSubtypeOf(*DynamicType::get())) { + } else if (arg.type->isSubtypeOf(DynamicType::get())) { // a single Tensor for this argument consume_n = 1; } else { @@ -89,7 +89,7 @@ at::optional> gatherTensorTypes(Node *node) { } for(size_t j = 0; j < consume_n; j++) { // bail out if a tensor does not have a size - TensorType *type = node->input(input_i++)->type()->cast(); + TensorTypePtr type = node->input(input_i++)->type()->cast(); if (!type) return at::nullopt; tensor_types.push_back(type); @@ -117,10 +117,10 @@ bool mergeTypes(ArrayRef lhs, ArrayRef rhs, ArrayRef out void PropagateShapeOnNode(Node * node, bool insert_expands=true); -void broadcastBinary(Node *node, std::vector& types, size_t idx1, size_t idx2) { +void broadcastBinary(Node *node, std::vector& types, size_t idx1, size_t idx2) { auto expected_size = at::infer_size(types[idx1]->sizes(), types[idx2]->sizes()); auto broadcast = [&](size_t input_idx) { - TensorType* input_type = types.at(input_idx); + TensorTypePtr input_type = types.at(input_idx); if (input_type->sizes() == expected_size) return; auto graph = node->owningGraph(); @@ -178,14 +178,14 @@ bool isValidArgumentForRunning(Value* v) { // allow constants if(toIValue(v)) return true; - if(TensorType* tt = v->type()->cast()) { + if(TensorTypePtr tt = v->type()->cast()) { return !at::isIntegralType(tt->scalarType()); } - return v->type()->isSubtypeOf(*FloatType::get()); + return v->type()->isSubtypeOf(FloatType::get()); } bool isValidReturnForRunning(Value* v) { - return v->type()->isSubtypeOf(*DynamicType::get()) || - v->type()->isSubtypeOf(*NumberType::get()); + return v->type()->isSubtypeOf(DynamicType::get()) || + v->type()->isSubtypeOf(NumberType::get()); } bool canPropagateShapeByRunningIt(Node* node) { @@ -244,7 +244,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { case prim::NumToTensor: return; // correct num type is already set case prim::Constant: { - if(node->output()->type()->isSubtypeOf(*DynamicType::get())) { + if(node->output()->type()->isSubtypeOf(DynamicType::get())) { node->output()->inferTypeFrom(node->t(attr::value)); } return; @@ -296,7 +296,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { auto lhs_type = tensor_types.at(0); auto rhs_type = tensor_types.at(1); SHAPE_ASSERT(lhs_type->sizes().size() == 2 && rhs_type->sizes().size() == 2); - node->output()->setType(std::make_shared( + node->output()->setType(TensorType::create( lhs_type->scalarType(), lhs_type->device(), at::IntList{lhs_type->sizes().at(0), rhs_type->sizes().at(1)})); return; @@ -419,7 +419,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { std::vector dim_vec = {(int64_t)tensor_types.at(0)->sizes().size()}; at::IntList dims(dim_vec); node->output()->setType( - std::make_shared(at::kLong, -1, dims)); + TensorType::create(at::kLong, -1, dims)); return; } else if (node->kind() == onnx::Reshape) { setUnshapedType(node); diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index c9e41e8a7eee26..05dbe341143d79 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -451,9 +451,9 @@ void initPythonIRBindings(PyObject * module_) { ; py::class_>(m, "DynamicType") - .def(py::init<>()); + .def(py::init([](){ return DynamicType::create(); })); py::class_>(m, "TupleType") - .def(py::init>()) + .def(py::init([](std::vector a){ return TupleType::create(a); })) .def("elements", [](TupleType &self){ std::vector types; for (auto type : self.elements()) { diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 010e0919f9cd03..b1fa7dc4c4185f 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -202,7 +202,7 @@ RegisterOperators reg({ prim::ListConstruct, [](Node* node) -> Operation { size_t num_inputs = node->inputs().size(); - ListType* lt = node->output()->type()->expect(); + ListTypePtr lt = node->output()->type()->expect(); if(IntType::get() == lt->getElementType()) { return [=](Stack& stack) { auto inputs = peekSlice(stack, 0, num_inputs, num_inputs); diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index be819775f1dc94..48f83881356c58 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -93,9 +93,9 @@ struct CastValue : public SugaredValue { throw ErrorReport(loc) << "expected a single argument for cast"; auto values = toValues(inputs); Value* input = values.at(0); - if(!input->type()->isSubtypeOf(*type)) { + if(!input->type()->isSubtypeOf(type)) { if(*type == *DynamicType::get()) { - if(!input->type()->isSubtypeOf(*NumberType::get())) { + if(!input->type()->isSubtypeOf(NumberType::get())) { throw ErrorReport(loc) << "expected a number"; } input = numToTensor(loc, input); @@ -244,7 +244,7 @@ struct Environment { throw ErrorReport(loc) << "Cannot re-assign '" << name << "' because it has type " << value->kind() << " and " << name << " is not a first-class value. Only reassignments to first-class values are allowed"; } - if(!as_simple_value->type()->isSubtypeOf(*unshapedType(simple_parent->type()))) { + if(!as_simple_value->type()->isSubtypeOf(unshapedType(simple_parent->type()))) { throw ErrorReport(loc) << "variable '" << name << "' previously has type " << simple_parent->type()->str() << " but is now being assigned to a value of type " << as_simple_value->type()->str(); } @@ -368,7 +368,7 @@ Value* createStack(Graph& g, const SourceRange& loc, at::ArrayRef inputs } static bool isTensorSubtype(Value* v) { - return v->type()->isSubtypeOf(*DynamicType::get()); + return v->type()->isSubtypeOf(DynamicType::get()); } at::optional> getIntListAttribute(at::optional N, Value* input) { @@ -538,12 +538,12 @@ at::optional> tryMatchSchema( // Allow tuples that only contain integers to turn into lists of integers if(*ListType::ofInts() == *arg.type && v.value->type()->kind() == TypeKind::TupleType && - v.value->type()->isSubtypeOf(*ListType::ofInts())) { + v.value->type()->isSubtypeOf(ListType::ofInts())) { auto unpacked = createTupleUnpack(v.value); v.value = graph.insertNode(graph.createList(IntType::get(), unpacked))->output(); } - if(!v.value->type()->isSubtypeOf(*arg.type)) { + if(!v.value->type()->isSubtypeOf(arg.type)) { err() << "expected a value of type " << arg.type->str() << " for argument '" << arg.name << "' but found " << v.value->type()->str() << "\n" << v.loc; @@ -551,7 +551,7 @@ at::optional> tryMatchSchema( } // we only support tensor lists for builtins, where they must be flattened - if(arg.type->isSubtypeOf(*ListType::ofTensors())) { + if(arg.type->isSubtypeOf(ListType::ofTensors())) { auto outputs = createTupleUnpack(v.value); flat_inputs.insert(flat_inputs.end(), outputs.begin(), outputs.end()); } else { @@ -663,7 +663,7 @@ static Value* ensureTensor(const SourceRange& range, Value* v) { } static Value* ensureInt(const SourceRange& range, Value* v) { - if(!v->type()->isSubtypeOf(*IntType::get())) { + if(!v->type()->isSubtypeOf(IntType::get())) { throw ErrorReport(range) << "expected a int but found a " << v->type()->str(); } @@ -778,7 +778,7 @@ struct to_ir { auto range = return_stmt.range(); size_t return_type_idx = 0; for (auto& r : results) { - if(r->type()->isSubtypeOf(*NumberType::get())) { + if(r->type()->isSubtypeOf(NumberType::get())) { graph->registerOutput(numToTensor(range, r)); } else { ensureTensor(range, r); @@ -787,7 +787,7 @@ struct to_ir { TypePtr type = DynamicType::get(); if (typed_def.schema) { type = typed_def.schema->returns.at(return_type_idx).type; - if (!r->type()->isSubtypeOf(*type)) { + if (!r->type()->isSubtypeOf(type)) { throw ErrorReport(return_stmt.range()) << "Return value at position " << return_type_idx << " was annotated as having type " << type->str() << " but is actually of type " << r->type()->str(); @@ -914,10 +914,10 @@ struct to_ir { Value* emitCond(Expr cond) { Value* v = emitExpr(cond, identity); - if(v->type()->isSubtypeOf(*DynamicType::get())) { + if(v->type()->isSubtypeOf(DynamicType::get())) { v = tensorToNum(cond.range(), v, IntType::get()); } - if(!v->type()->isSubtypeOf(*IntType::get())) { + if(!v->type()->isSubtypeOf(IntType::get())) { throw ErrorReport(cond) << "expected a tensor or integer expression for condition but found " << v->type()->str(); } return v; diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 576344427c0461..cb7893234dc747 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -86,7 +86,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { << "of arguments: expected " << arguments.size() << ", but got " << inputs.size(); for (size_t i = 0; i < arguments.size(); ++i) { - if (!inputs[i]->type()->isSubtypeOf(*arguments[i])) + if (!inputs[i]->type()->isSubtypeOf(arguments[i])) throw ErrorReport(loc) << "type mismatch at argument " << i << ": expected " << arguments[i]->str() << ", but got " << inputs[i]->type()->str(); } @@ -135,7 +135,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { // equivalent, but the PythonOp impl ends with an optional tuple unpack, so we need // to do it. for (auto & ret_type_elem : returns) { - if (!ret_type_elem->isSubtypeOf(*DynamicType::get())) { + if (!ret_type_elem->isSubtypeOf(DynamicType::get())) { throw ErrorReport(loc) << "Python functions can currently only return Tensors"; } } diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 76518aaf1d26fa..1120d0bcaad740 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -119,14 +119,14 @@ struct Method { for (size_t i=0; i < retval->inputs().size(); ++i) { auto scalar_type = inputs[i].type().scalarType(); auto sizes = inputs[i].sizes(); - auto type = std::make_shared(scalar_type, -1, sizes); + auto type = torch::jit::TensorType::create(scalar_type, -1, sizes); retval->inputs()[i]->setType(type); } JIT_ASSERT(retval->outputs().size() == outputs.size()); for (size_t i=0; i < retval->outputs().size(); ++i) { auto scalar_type = outputs[i].type().scalarType(); auto sizes = outputs[i].sizes(); - auto type = std::make_shared(scalar_type, -1, sizes); + auto type = torch::jit::TensorType::create(scalar_type, -1, sizes); retval->outputs()[i]->setType(type); } return retval; diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp index ecb8c9b3779816..8c9763f88353e5 100644 --- a/torch/csrc/jit/test_jit.cpp +++ b/torch/csrc/jit/test_jit.cpp @@ -641,7 +641,7 @@ std::string toString(std::shared_ptr& graph) { void testDifferentiate(std::ostream & out) { auto graph = std::make_shared(); at::ScalarType s = at::ScalarType::Float; - auto type = std::shared_ptr(new TensorType(s, -1, {2, 3, 4}, {12, 4, 1})); + auto type = TensorType::create(s, -1, {2, 3, 4}, {12, 4, 1}); // Build up a fake graph auto a = SymbolicVariable::asNewInput(*graph, type); @@ -668,7 +668,7 @@ void testDifferentiate(std::ostream & out) { void testDifferentiateWithRequiresGrad(std::ostream & out) { auto graph = std::make_shared(); at::ScalarType s = at::ScalarType::Float; - auto type = std::shared_ptr(new TensorType(s, -1, {2, 3, 4}, {12, 4, 1})); + auto type = TensorType::create(s, -1, {2, 3, 4}, {12, 4, 1}); // Build up a fake graph auto a = SymbolicVariable::asNewInput(*graph, type); diff --git a/torch/csrc/jit/type.cpp b/torch/csrc/jit/type.cpp index b657d4e935f17a..bf28588ad7eca6 100644 --- a/torch/csrc/jit/type.cpp +++ b/torch/csrc/jit/type.cpp @@ -45,29 +45,29 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { } TypePtr DynamicType::get() { - static auto value = std::make_shared(); + static auto value = DynamicType::create(); return value; } TypePtr NumberType::get() { - static auto value = std::make_shared(); + static auto value = NumberType::create(); return value; } TypePtr IntType::get() { - static auto value = std::make_shared(); + static auto value = IntType::create(); return value; } TypePtr FloatType::get() { - static auto value = std::make_shared(); + static auto value = FloatType::create(); return value; } TypePtr ListType::ofTensors() { - static auto value = std::make_shared(DynamicType::get()); + static auto value = ListType::create(DynamicType::get()); return value; } TypePtr ListType::ofInts() { - static auto value = std::make_shared(IntType::get()); + static auto value = ListType::create(IntType::get()); return value; } diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h index 177833f23d938e..dc2cea1fa50b94 100644 --- a/torch/csrc/jit/type.h +++ b/torch/csrc/jit/type.h @@ -31,7 +31,7 @@ struct Type; using TypePtr = std::shared_ptr; -struct TORCH_API Type { +struct TORCH_API Type : std::enable_shared_from_this { private: TypeKind kind_; @@ -44,8 +44,8 @@ struct TORCH_API Type { // subtyping relation. By default, we return true for the case // when the type is exactly equal - virtual bool isSubtypeOf(const Type& rhs) const { - return *this == rhs; + virtual bool isSubtypeOf(const TypePtr rhs) const { + return *this == *rhs; } // user-friendly form of the type, separate from // operator<< which is verbose and unambiguous @@ -58,26 +58,26 @@ struct TORCH_API Type { // Dynamically cast this object to the subclass indicated by the // template variable, returning nullptr if the cast is invalid.. template - T* cast() { + std::shared_ptr cast() { if (T::Kind == kind()) - return static_cast(this); + return std::static_pointer_cast(shared_from_this()); return nullptr; } template - const T* cast() const { + std::shared_ptr cast() const { if (T::Kind == kind()) - return static_cast(this); + return std::static_pointer_cast(shared_from_this()); return nullptr; } template - T* expect() { + std::shared_ptr expect() { JIT_ASSERT(T::Kind == kind()); - return static_cast(this); + return std::static_pointer_cast(shared_from_this()); } template - const T* expect() const { + std::shared_ptr expect() const { JIT_ASSERT(T::Kind == kind()); - return static_cast(this); + return std::static_pointer_cast(shared_from_this()); } virtual ~Type() {} }; @@ -86,10 +86,15 @@ inline bool operator!=(const Type & lhs, const Type & rhs) { return !(lhs == rhs); } +struct DynamicType; +using DynamicTypePtr = std::shared_ptr; // This node represents a single Tensor value, with an unknown shape. struct TORCH_API DynamicType : public Type { - DynamicType() - : Type(TypeKind::DynamicType) {} + template + static DynamicTypePtr create( T&& ... all ) { + return DynamicTypePtr(new DynamicType( std::forward(all)... )); + } + bool operator==(const Type& rhs) const override { return rhs.kind() == kind(); } @@ -99,6 +104,9 @@ struct TORCH_API DynamicType : public Type { static const TypeKind Kind = TypeKind::DynamicType; // global singleton static TypePtr get(); +private: + DynamicType() + : Type(TypeKind::DynamicType) {} }; struct TensorType; @@ -106,21 +114,18 @@ using TensorTypePtr = std::shared_ptr; // This node represents a single Tensor value with a specific size struct TORCH_API TensorType : public Type { friend struct Type; - TensorType(const at::Tensor& tensor) - : Type(TypeKind::TensorType) - , scalar_type_(tensor.type().scalarType()) - , device_(tensor.type().is_cuda() ? tensor.get_device() : -1) - , sizes_(tensor.sizes()) - , strides_(tensor.strides()) {} - TensorType(at::ScalarType scalar_type, int device, at::IntList sizes) - : TensorType(scalar_type, device, sizes, TensorType::contiguousStridesOf(sizes)) {} - TensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides) - : Type(TypeKind::TensorType) - , scalar_type_(scalar_type) - , device_(device) - , sizes_(sizes) - , strides_(strides) - {} + template + static TensorTypePtr create( T&& ... all ) { + return TensorTypePtr(new TensorType( std::forward(all)... )); + } + + // overloaded create variadic template argument as it could not distinguish initializer list + static TensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes) { + return TensorTypePtr(new TensorType(scalar_type, device, sizes)); + } + static TensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides) { + return TensorTypePtr(new TensorType(scalar_type, device, sizes, strides)); + } static const TypeKind Kind = TypeKind::TensorType; @@ -130,7 +135,7 @@ struct TORCH_API TensorType : public Type { const std::vector& strides() const { return strides_; } TypePtr withSizesStrides(at::IntList sizes, at::IntList strides) const { - return std::make_shared(scalar_type_, device_, sizes, strides); + return TensorType::create(scalar_type_, device_, sizes, strides); } TypePtr withSizes(at::IntList sizes) const { @@ -138,13 +143,13 @@ struct TORCH_API TensorType : public Type { } TensorTypePtr contiguous() const { - auto t = std::make_shared(*this); + auto t = TensorType::create(*this); t->strides_ = TensorType::contiguousStridesOf(sizes_); return t; } TensorTypePtr toScalarType(at::ScalarType type){ - auto t = std::make_shared(*this); + auto t = TensorType::create(*this); t->scalar_type_ = type; return t; } @@ -158,8 +163,8 @@ struct TORCH_API TensorType : public Type { strides() == rt->strides() && device() == rt->device(); } - bool isSubtypeOf(const Type& rhs) const override { - return *this == rhs || rhs.kind() == TypeKind::DynamicType; + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs || rhs->kind() == TypeKind::DynamicType; } std::string str() const override { // str is used for user-facing error messages, where we @@ -176,6 +181,21 @@ struct TORCH_API TensorType : public Type { static TypePtr fromNumberType(TypePtr typ); private: + TensorType(const at::Tensor& tensor) + : Type(TypeKind::TensorType) + , scalar_type_(tensor.type().scalarType()) + , device_(tensor.type().is_cuda() ? tensor.get_device() : -1) + , sizes_(tensor.sizes()) + , strides_(tensor.strides()) {} + TensorType(at::ScalarType scalar_type, int device, at::IntList sizes) + : TensorType(scalar_type, device, sizes, TensorType::contiguousStridesOf(sizes)) {} + TensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides) + : Type(TypeKind::TensorType) + , scalar_type_(scalar_type) + , device_(device) + , sizes_(sizes) + , strides_(strides) + {} static std::vector contiguousStridesOf(at::IntList sizes) { std::vector strides(sizes.size()); if(sizes.size() == 0) // zero-dim case @@ -192,11 +212,15 @@ struct TORCH_API TensorType : public Type { std::vector strides_; }; +struct ListType; +using ListTypePtr = std::shared_ptr; + struct TORCH_API ListType : public Type { friend struct Type; - static const TypeKind Kind = TypeKind::ListType; - ListType(TypePtr elem) - : Type(TypeKind::ListType), elem(elem) {} + template + static ListTypePtr create( T&& ... all ) { + return ListTypePtr(new ListType( std::forward(all)... )); + } bool operator==(const Type& rhs) const override { if(auto rhs_ = rhs.cast()) { return *getElementType() == *rhs_->getElementType(); @@ -215,35 +239,41 @@ struct TORCH_API ListType : public Type { static TypePtr ofTensors(); static TypePtr ofInts(); private: + ListType(TypePtr elem) + : Type(TypeKind::ListType), elem(elem) {} + static const TypeKind Kind = TypeKind::ListType; TypePtr elem; }; +struct TupleType; +using TupleTypePtr = std::shared_ptr; + struct TORCH_API TupleType : public Type { friend struct Type; - TupleType(std::vector elements_) - : Type(TypeKind::TupleType) - , elements_(std::move(elements_)) {} - static const TypeKind Kind = TypeKind::TupleType; + template + static TupleTypePtr create( T&& ... all ) { + return TupleTypePtr(new TupleType( std::forward(all)... )); + } at::ArrayRef elements() const { return elements_; } bool operator==(const Type& rhs) const override { - return compare(rhs, [](const Type& a, const Type& b) { - return a == b; + return compare(rhs, [](const TypePtr a, const TypePtr b) { + return *a == *b; }); } - bool isSubtypeOf(const Type& rhs) const override { + bool isSubtypeOf(const TypePtr rhs) const override { // e.g. (Tensor, Tensor, Tensor) <: List[Tensor] - if(auto lt = rhs.cast()) { + if(auto lt = rhs->cast()) { for(auto e : elements()) { - if(!e->isSubtypeOf(*lt->getElementType())) + if(!e->isSubtypeOf(lt->getElementType())) return false; } return true; } // co-variant rules for tuples - return compare(rhs, [](const Type& a, const Type&b) { - return a.isSubtypeOf(b); + return compare(*rhs, [](const TypePtr a, const TypePtr b) { + return a->isSubtypeOf(b); }); } std::string str() const override { @@ -258,7 +288,12 @@ struct TORCH_API TupleType : public Type { return ss.str(); } private: - bool compare(const Type& rhs, std::function fn) const { + TupleType(std::vector elements_) + : Type(TypeKind::TupleType) + , elements_(std::move(elements_)) {} + static const TypeKind Kind = TypeKind::TupleType; + + bool compare(const Type& rhs, std::function fn) const { if(rhs.kind() != kind()) return false; const auto & l_elements = elements(); @@ -266,7 +301,7 @@ struct TORCH_API TupleType : public Type { if(l_elements.size() != r_elements.size()) return false; for(size_t i = 0; i < l_elements.size(); ++i) { - if(!fn(*l_elements[i], *r_elements[i])) + if(!fn(l_elements[i], r_elements[i])) return false; } return true; @@ -274,10 +309,14 @@ struct TORCH_API TupleType : public Type { std::vector elements_; }; +struct NumberType; +using NumberTypePtr = std::shared_ptr; // This node represents a Python number value struct TORCH_API NumberType : public Type { - NumberType() - : Type(TypeKind::NumberType) {} + template + static NumberTypePtr create( T&& ... all ) { + return NumberTypePtr(new NumberType( std::forward(all)... )); + } bool operator==(const Type& rhs) const override { return rhs.kind() == kind(); } @@ -287,42 +326,59 @@ struct TORCH_API NumberType : public Type { static const TypeKind Kind = TypeKind::NumberType; // global singleton static TypePtr get(); +private: + NumberType() + : Type(TypeKind::NumberType) {} }; +struct FloatType; +using FloatTypePtr = std::shared_ptr; // This node represents a Python float number value struct TORCH_API FloatType : public Type { - FloatType() - : Type(TypeKind::FloatType) {} + template + static FloatTypePtr create( T&& ... all ) { + return FloatTypePtr(new FloatType( std::forward(all)... )); + } bool operator==(const Type& rhs) const override { return rhs.kind() == kind(); } std::string str() const override { return "float"; } - bool isSubtypeOf(const Type& rhs) const override { - return *this == rhs || rhs.kind() == TypeKind::NumberType; + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs || rhs->kind() == TypeKind::NumberType; } static const TypeKind Kind = TypeKind::FloatType; // global singleton static TypePtr get(); +private: + FloatType() + : Type(TypeKind::FloatType) {} }; +struct IntType; +using IntTypePtr = std::shared_ptr; // This node represents a Python int number value struct TORCH_API IntType : public Type { - IntType() - : Type(TypeKind::IntType) {} + template + static IntTypePtr create( T&& ... all ) { + return IntTypePtr(new IntType( std::forward(all)... )); + } bool operator==(const Type& rhs) const override { return rhs.kind() == kind(); } std::string str() const override { return "int"; } - bool isSubtypeOf(const Type& rhs) const override { - return *this == rhs || rhs.kind() == TypeKind::NumberType; + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs || rhs->kind() == TypeKind::NumberType; } static const TypeKind Kind = TypeKind::IntType; // global singleton static TypePtr get(); +private: + IntType() + : Type(TypeKind::IntType) {} }; @@ -331,10 +387,10 @@ TORCH_API std::ostream& operator<<(std::ostream & out, const Type & t); // e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) inline TypePtr unshapedType(const TypePtr& type) { - if(TupleType* t = type->cast()) { - return std::make_shared(fmap(t->elements(), unshapedType)); - } else if(ListType* t = type->cast()) { - return std::make_shared(unshapedType(t->getElementType())); + if(TupleTypePtr t = type->cast()) { + return TupleType::create(fmap(t->elements(), unshapedType)); + } else if(ListTypePtr t = type->cast()) { + return ListType::create(unshapedType(t->getElementType())); } else if(type->kind() == TypeKind::TensorType) { return DynamicType::get(); } else { @@ -343,13 +399,11 @@ inline TypePtr unshapedType(const TypePtr& type) { } inline TypePtr TensorType::fromNumberType(TypePtr typ) { - JIT_ASSERT(typ->isSubtypeOf(*NumberType::get())); - if(typ->isSubtypeOf(*IntType::get())) { - TensorType tt(at::kLong, -1, {}); - return std::make_shared(std::move(tt)); - } else if(typ->isSubtypeOf(*FloatType::get())) { - TensorType tt(at::kFloat, -1, {}); - return std::make_shared(std::move(tt)); + JIT_ASSERT(typ->isSubtypeOf(NumberType::get())); + if(typ->isSubtypeOf(IntType::get())) { + return TensorType::create(at::kLong, -1, {}); + } else if(typ->isSubtypeOf(FloatType::get())) { + return TensorType::create(at::kFloat, -1, {}); } AT_ERROR("unknown number type", typ->str()); } From e7ab093d93c968158430ad02cf52a22f829a0816 Mon Sep 17 00:00:00 2001 From: Jongsoo Park Date: Thu, 26 Jul 2018 18:08:28 -0700 Subject: [PATCH 06/17] Simplify order switch operators (#9581) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9581 Mostly to simplify code. Should also improve performance but order switch ops don't take much time anyway. Reviewed By: viswanathgs Differential Revision: D8909766 fbshipit-source-id: 17a302d5bf4aba2755d88223fc01a41fd72c5919 --- caffe2/operators/order_switch_ops.cc | 31 ++++------- .../python/operator_test/order_switch_test.py | 52 +++++++++++++++++++ 2 files changed, 61 insertions(+), 22 deletions(-) create mode 100644 caffe2/python/operator_test/order_switch_test.py diff --git a/caffe2/operators/order_switch_ops.cc b/caffe2/operators/order_switch_ops.cc index 11cc6dedc24f9f..7296f9a74afa51 100644 --- a/caffe2/operators/order_switch_ops.cc +++ b/caffe2/operators/order_switch_ops.cc @@ -10,16 +10,10 @@ bool NHWC2NCHWOp::RunOnDevice() { const int N = X.dim32(0), H = X.dim32(1), W = X.dim32(2), C = X.dim32(3); Y->Resize(N, C, H, W); const float* Xdata = X.data(); - float* Ydata = Y->mutable_data(); - for (int n = 0; n < N; ++n) { - for (int h = 0; h < H; ++h) { - for (int w = 0; w < W; ++w) { - for (int c = 0; c < C; ++c) { - Ydata[((n * C + c) * H + h) * W + w] = *(Xdata++); - } - } - } - } + float* Ydata = Y->template mutable_data(); + std::array dims = {N, H, W, C}; + std::array axes = {0, 3, 1, 2}; + math::Transpose(4, dims.data(), axes.data(), Xdata, Ydata, &context_); return true; } @@ -31,20 +25,13 @@ bool NCHW2NHWCOp::RunOnDevice() { const int N = X.dim32(0), C = X.dim32(1), H = X.dim32(2), W = X.dim32(3); Y->Resize(N, H, W, C); const float* Xdata = X.data(); - float* Ydata = Y->mutable_data(); - for (int n = 0; n < N; ++n) { - for (int c = 0; c < C; ++c) { - for (int h = 0; h < H; ++h) { - for (int w = 0; w < W; ++w) { - Ydata[((n * H + h) * W + w) * C + c] = *(Xdata++); - } - } - } - } + float* Ydata = Y->template mutable_data(); + std::array dims = {N, C, H, W}; + std::array axes = {0, 2, 3, 1}; + math::Transpose(4, dims.data(), axes.data(), Xdata, Ydata, &context_); return true; } - REGISTER_CPU_OPERATOR(NHWC2NCHW, NHWC2NCHWOp); REGISTER_CPU_OPERATOR(NCHW2NHWC, NCHW2NHWCOp); @@ -102,4 +89,4 @@ class GetNCHW2NHWCGradient : public GradientMakerBase { } }; REGISTER_GRADIENT(NCHW2NHWC, GetNCHW2NHWCGradient); -} // namespace caffe2 +} // namespace caffe2 diff --git a/caffe2/python/operator_test/order_switch_test.py b/caffe2/python/operator_test/order_switch_test.py new file mode 100644 index 00000000000000..71ba64e40f3ffb --- /dev/null +++ b/caffe2/python/operator_test/order_switch_test.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np +import caffe2.python.hypothesis_test_util as hu +from caffe2.python import core +from hypothesis import given +import hypothesis.strategies as st + + +class OrderSwitchOpsTest(hu.HypothesisTestCase): + @given( + n=st.integers(1, 5), + c=st.integers(1, 5), + h=st.integers(1, 5), + w=st.integers(1, 5), + **hu.gcs) + def test_nchw2nhwc(self, n, c, h, w, gc, dc): + X = np.random.randn(n, c, h, w).astype(np.float32) + + op = core.CreateOperator("NCHW2NHWC", ["X"], ["Y"], + device_option=gc) + + def nchw2nhwc_ref(X): + X_reshaped = X.transpose((0, 2, 3, 1)) + return (X_reshaped,) + + self.assertReferenceChecks(gc, op, [X], nchw2nhwc_ref) + self.assertGradientChecks(gc, op, [X], 0, [0]) + self.assertDeviceChecks(dc, op, [X], [0]) + + @given( + n=st.integers(1, 5), + c=st.integers(1, 5), + h=st.integers(1, 5), + w=st.integers(1, 5), + **hu.gcs) + def test_nhwc2nchw(self, n, c, h, w, gc, dc): + X = np.random.randn(n, h, w, c).astype(np.float32) + + op = core.CreateOperator("NHWC2NCHW", ["X"], ["Y"], + device_option=gc) + + def nhwc2nchw_ref(X): + X_reshaped = X.transpose((0, 3, 1, 2)) + return (X_reshaped,) + + self.assertReferenceChecks(gc, op, [X], nhwc2nchw_ref) + self.assertGradientChecks(gc, op, [X], 0, [0]) + self.assertDeviceChecks(dc, op, [X], [0]) From c045e969b67e232165c47fdf3507429ce7d2f432 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Thu, 26 Jul 2018 18:39:51 -0700 Subject: [PATCH 07/17] Use qualified name at::Half in Dispatch.h (#9848) Summary: This makes AT_DISPATCH_ALL_TYPES_AND_HALF valid outside of the at namespace. See https://github.com/pytorch/extension-cpp/issues/15 Pull Request resolved: https://github.com/pytorch/pytorch/pull/9848 Differential Revision: D9006921 Pulled By: colesbury fbshipit-source-id: a6e4f097a9d6fb85c921e1c9b9ea25d0f2db06dc --- aten/src/ATen/Dispatch.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 6cd8722316297e..be06656a3dee7c 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -17,7 +17,7 @@ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ + AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ } \ }() @@ -27,9 +27,9 @@ switch (the_type.scalarType()) { \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, Half, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ + AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ } \ }() @@ -43,7 +43,7 @@ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ + AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ } \ }() @@ -59,7 +59,7 @@ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ + AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ } \ }() @@ -74,8 +74,8 @@ AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, Half, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ + AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ } \ }() From dfa0af093d6ff03460a2a7167eed8deed99aca72 Mon Sep 17 00:00:00 2001 From: Yi Cheng Date: Thu, 26 Jul 2018 18:54:07 -0700 Subject: [PATCH 08/17] Move predictor into caffe2/caffe2/predictor (#9548) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9548 Pull Request resolved: https://github.com/pytorch/translate/pull/157 One part of refactor predictor. Move all the files into predictor dir. Reviewed By: highker Differential Revision: D8845276 fbshipit-source-id: 1e917464b0c8a042f025128a082c784eaa3b7013 --- binaries/predictor_verifier.cc | 2 +- caffe2/CMakeLists.txt | 1 + caffe2/mobile/contrib/ios/ios_caffe.cc | 2 +- caffe2/mobile/contrib/ios/ios_caffe.h | 2 +- caffe2/mobile/contrib/ios/ios_caffe_predictor.h | 2 +- caffe2/mobile/contrib/opengl/core/GLPredictor.h | 2 +- caffe2/mobile/contrib/opengl/core/rewrite_net.h | 2 +- caffe2/onnx/backend_rep.h | 2 +- caffe2/predictor/CMakeLists.txt | 13 +++++++++++++ caffe2/{core => predictor}/predictor.cc | 2 +- caffe2/{core => predictor}/predictor.h | 2 +- caffe2/{core => predictor}/predictor_config.h | 0 caffe2/{core => predictor}/predictor_test.cc | 2 +- caffe2/{core => predictor}/predictor_utils.cc | 2 +- caffe2/{core => predictor}/predictor_utils.h | 0 caffe2/python/pybind_state.cc | 2 +- 16 files changed, 26 insertions(+), 12 deletions(-) create mode 100644 caffe2/predictor/CMakeLists.txt rename caffe2/{core => predictor}/predictor.cc (99%) rename caffe2/{core => predictor}/predictor.h (97%) rename caffe2/{core => predictor}/predictor_config.h (100%) rename caffe2/{core => predictor}/predictor_test.cc (99%) rename caffe2/{core => predictor}/predictor_utils.cc (98%) rename caffe2/{core => predictor}/predictor_utils.h (100%) diff --git a/binaries/predictor_verifier.cc b/binaries/predictor_verifier.cc index e82a8e9d2cec85..e8e29f29559cee 100644 --- a/binaries/predictor_verifier.cc +++ b/binaries/predictor_verifier.cc @@ -16,7 +16,7 @@ #include "caffe2/core/flags.h" #include "caffe2/core/init.h" -#include "caffe2/core/predictor.h" +#include "caffe2/predictor/predictor.h" #include "caffe2/utils/proto_utils.h" CAFFE2_DEFINE_string(init_net, "", "The given path to the init protobuffer."); diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index c92357b44680f5..088390c1752412 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -65,6 +65,7 @@ if(BUILD_CAFFE2) add_subdirectory(proto) add_subdirectory(contrib) add_subdirectory(core) + add_subdirectory(predictor) add_subdirectory(core/nomnigraph) add_subdirectory(core/dispatch) if (USE_NVRTC) diff --git a/caffe2/mobile/contrib/ios/ios_caffe.cc b/caffe2/mobile/contrib/ios/ios_caffe.cc index 12e0e5598c6aa0..f1bcf4a5b1087a 100644 --- a/caffe2/mobile/contrib/ios/ios_caffe.cc +++ b/caffe2/mobile/contrib/ios/ios_caffe.cc @@ -1,8 +1,8 @@ #include "ios_caffe.h" -#include "caffe2/core/predictor.h" #include "caffe2/core/tensor.h" #include "caffe2/mobile/contrib/ios/ios_caffe_predictor.h" +#include "caffe2/predictor/predictor.h" Caffe2IOSPredictor* MakeCaffe2Predictor(const std::string& init_net_str, const std::string& predict_net_str, diff --git a/caffe2/mobile/contrib/ios/ios_caffe.h b/caffe2/mobile/contrib/ios/ios_caffe.h index 3fbd235a74f706..7b5f8170405b6f 100644 --- a/caffe2/mobile/contrib/ios/ios_caffe.h +++ b/caffe2/mobile/contrib/ios/ios_caffe.h @@ -3,9 +3,9 @@ #include #include -#include "caffe2/core/predictor.h" #include "caffe2/mobile/contrib/ios/ios_caffe_defines.h" #include "caffe2/mobile/contrib/ios/ios_caffe_predictor.h" +#include "caffe2/predictor/predictor.h" extern "C" { diff --git a/caffe2/mobile/contrib/ios/ios_caffe_predictor.h b/caffe2/mobile/contrib/ios/ios_caffe_predictor.h index 0b065d3c426956..a51711ce0558e5 100644 --- a/caffe2/mobile/contrib/ios/ios_caffe_predictor.h +++ b/caffe2/mobile/contrib/ios/ios_caffe_predictor.h @@ -3,8 +3,8 @@ #include #include "caffe2/core/net.h" -#include "caffe2/core/predictor.h" #include "caffe2/mobile/contrib/ios/ios_caffe_defines.h" +#include "caffe2/predictor/predictor.h" struct Tensor { std::vector dims; diff --git a/caffe2/mobile/contrib/opengl/core/GLPredictor.h b/caffe2/mobile/contrib/opengl/core/GLPredictor.h index 2806f8a0408293..24c319759bd7d1 100644 --- a/caffe2/mobile/contrib/opengl/core/GLPredictor.h +++ b/caffe2/mobile/contrib/opengl/core/GLPredictor.h @@ -3,7 +3,7 @@ #include "GLImage.h" #include "caffe2/core/net.h" -#include "caffe2/core/predictor.h" +#include "caffe2/predictor/predictor.h" namespace caffe2 { class GLPredictor : public Predictor { diff --git a/caffe2/mobile/contrib/opengl/core/rewrite_net.h b/caffe2/mobile/contrib/opengl/core/rewrite_net.h index c3c47d63f75065..d0bc921a8876ca 100644 --- a/caffe2/mobile/contrib/opengl/core/rewrite_net.h +++ b/caffe2/mobile/contrib/opengl/core/rewrite_net.h @@ -1,7 +1,7 @@ #pragma once #include "GLPredictor.h" -#include "caffe2/core/predictor.h" +#include "caffe2/predictor/predictor.h" namespace caffe2 { bool tryConvertToOpenGL(const NetDef& initNet, diff --git a/caffe2/onnx/backend_rep.h b/caffe2/onnx/backend_rep.h index 5fe503bbe7ad98..fb46d19d10ba43 100644 --- a/caffe2/onnx/backend_rep.h +++ b/caffe2/onnx/backend_rep.h @@ -1,6 +1,6 @@ #pragma once -#include "caffe2/core/predictor.h" +#include "caffe2/predictor/predictor.h" #include "caffe2/proto/caffe2.pb.h" #include diff --git a/caffe2/predictor/CMakeLists.txt b/caffe2/predictor/CMakeLists.txt new file mode 100644 index 00000000000000..1038a84af38da5 --- /dev/null +++ b/caffe2/predictor/CMakeLists.txt @@ -0,0 +1,13 @@ +set(Caffe2_PREDICTOR_CPU_SRC + "${CMAKE_CURRENT_SOURCE_DIR}/predictor.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/predictor_utils.cc" +) +set(Caffe2_PREDICTOR_CPU_TEST_SRC + "${CMAKE_CURRENT_SOURCE_DIR}/predictor_test.cc") + +# Common files that are always going to be included. +list(APPEND Caffe2_CPU_SRCS ${Caffe2_PREDICTOR_CPU_SRC}) +list(APPEND Caffe2_CPU_TEST_SRCS ${Caffe2_PREDICTOR_CPU_TEST_SRC}) + +set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) +set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE) diff --git a/caffe2/core/predictor.cc b/caffe2/predictor/predictor.cc similarity index 99% rename from caffe2/core/predictor.cc rename to caffe2/predictor/predictor.cc index 2aaa7a2dac3a30..8c3001571d2f3c 100644 --- a/caffe2/core/predictor.cc +++ b/caffe2/predictor/predictor.cc @@ -1,4 +1,4 @@ -#include "caffe2/core/predictor.h" +#include "caffe2/predictor/predictor.h" #ifdef CAFFE2_OPTIMIZER #include "caffe2/opt/optimizer.h" #endif diff --git a/caffe2/core/predictor.h b/caffe2/predictor/predictor.h similarity index 97% rename from caffe2/core/predictor.h rename to caffe2/predictor/predictor.h index b56401a35da5c3..a3f05d7aacac89 100644 --- a/caffe2/core/predictor.h +++ b/caffe2/predictor/predictor.h @@ -2,8 +2,8 @@ #include #include "caffe2/core/net.h" -#include "caffe2/core/predictor_config.h" #include "caffe2/core/tensor.h" +#include "caffe2/predictor/predictor_config.h" #include "caffe2/proto/metanet.pb.h" #include "caffe2/proto/predictor_consts.pb.h" diff --git a/caffe2/core/predictor_config.h b/caffe2/predictor/predictor_config.h similarity index 100% rename from caffe2/core/predictor_config.h rename to caffe2/predictor/predictor_config.h diff --git a/caffe2/core/predictor_test.cc b/caffe2/predictor/predictor_test.cc similarity index 99% rename from caffe2/core/predictor_test.cc rename to caffe2/predictor/predictor_test.cc index a37dbbb9e8d39e..31a102aea8712b 100644 --- a/caffe2/core/predictor_test.cc +++ b/caffe2/predictor/predictor_test.cc @@ -1,7 +1,7 @@ #include "caffe2/core/context.h" #include "caffe2/core/operator.h" -#include "caffe2/core/predictor.h" #include "caffe2/core/tensor.h" +#include "caffe2/predictor/predictor.h" #include "caffe2/utils/math.h" #include diff --git a/caffe2/core/predictor_utils.cc b/caffe2/predictor/predictor_utils.cc similarity index 98% rename from caffe2/core/predictor_utils.cc rename to caffe2/predictor/predictor_utils.cc index dea0388fc12528..cc37eec85fbaa1 100644 --- a/caffe2/core/predictor_utils.cc +++ b/caffe2/predictor/predictor_utils.cc @@ -1,4 +1,4 @@ -#include "caffe2/core/predictor_utils.h" +#include "caffe2/predictor/predictor_utils.h" #include "caffe2/core/blob.h" #include "caffe2/core/logging.h" diff --git a/caffe2/core/predictor_utils.h b/caffe2/predictor/predictor_utils.h similarity index 100% rename from caffe2/core/predictor_utils.h rename to caffe2/predictor/predictor_utils.h diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 04df247d821daf..9256896bd8d138 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -12,7 +12,6 @@ #include "caffe2/core/db.h" #include "caffe2/core/numa.h" #include "caffe2/core/operator.h" -#include "caffe2/core/predictor.h" #include "caffe2/core/stats.h" #include "caffe2/core/transform.h" #include "caffe2/mkl/mkl_utils.h" @@ -28,6 +27,7 @@ #include "caffe2/opt/optimize_ideep.h" #include "caffe2/opt/passes.h" #include "caffe2/opt/sink.h" +#include "caffe2/predictor/predictor.h" #include "caffe2/utils/cpuid.h" #include "caffe2/utils/string_utils.h" From a841006353b3b460965832d91443a8e982c12e0f Mon Sep 17 00:00:00 2001 From: Owen Anderson Date: Thu, 26 Jul 2018 19:49:05 -0700 Subject: [PATCH 09/17] Simplify some code by directly constructing unordered_set from nodes. Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9675 Differential Revision: D8952196 Pulled By: resistor fbshipit-source-id: 5ef2308fed9f702021f650cf2d241a83d880d359 --- torch/csrc/DynamicTypes.cpp | 3 +-- torch/csrc/jit/passes/create_autodiff_subgraphs.cpp | 5 +---- torch/csrc/jit/script/compiler.cpp | 5 +++-- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index b9dfa25d9ee870..f2165f4efa6d83 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -140,8 +140,7 @@ PyObject* createPyObject(const at::Storage& storage) bool isStorage(PyObject* obj) { - auto it = py_storage_type_to_attype.find(Py_TYPE(obj)); - return it != py_storage_type_to_attype.end(); + return py_storage_type_to_attype.count(Py_TYPE(obj)); } std::unique_ptr createStorage(PyObject* obj) { diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index 28b1195efd5273..d37ff6dfea5b43 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -35,10 +35,7 @@ void mergeNodes(Block * block, Symbol group_node_kind, ArrayRef nodes) { value_map[v] = nv; return nv; }; - std::unordered_set group_set; - for(auto n : nodes) { - group_set.insert(n); - } + std::unordered_set group_set(nodes.begin(), nodes.end()); for(auto n : nodes) { auto nn = new_graph->appendNode(new_graph->createClone(n, getOrCreateInput)); for(size_t i = 0; i < nn->outputs().size(); ++i) { diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 48f83881356c58..aa39746d779ec3 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -149,8 +149,9 @@ struct Environment { std::shared_ptr next; SugaredValuePtr findInThisFrame(const std::string& name) { - if (value_table.count(name)) { - return value_table.at(name); + auto it = value_table.find(name); + if (it != value_table.end()) { + return it->second; } return nullptr; } From e41eb4332783ce32d26fdf3418688a583a5cc478 Mon Sep 17 00:00:00 2001 From: Vishwak Srinivasan Date: Thu, 26 Jul 2018 20:48:55 -0700 Subject: [PATCH 10/17] Remove deprecated masked_copy (#9819) Summary: No tests are affected by this removal. Closes https://github.com/pytorch/pytorch/issues/1885 and closes #9817 While I was at it, I also fixed #9876 . Pull Request resolved: https://github.com/pytorch/pytorch/pull/9819 Differential Revision: D9018126 Pulled By: SsnL fbshipit-source-id: a9142bf4e2403bef05779a097f61fa8b7db04b71 --- docs/source/tensors.rst | 1 + docs/source/torch.rst | 1 + torch/tensor.py | 8 -------- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 05909a692b2a5b..c3c85797b4cd82 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -224,6 +224,7 @@ view of a storage and defines numeric operations on it. .. automethod:: expand_as .. automethod:: exponential_ .. automethod:: fill_ + .. automethod:: flatten .. automethod:: flip .. automethod:: float .. automethod:: floor diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 3ee7d6e7abe68c..c1e914c03c74e7 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -259,6 +259,7 @@ Other Operations .. autofunction:: diagflat .. autofunction:: diagonal .. autofunction:: einsum +.. autofunction:: flatten .. autofunction:: flip .. autofunction:: histc .. autofunction:: meshgrid diff --git a/torch/tensor.py b/torch/tensor.py index 60a50b6b67b454..b9a35e39ae6952 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -292,14 +292,6 @@ def scatter(self, dim, index, source): def scatter_add(self, dim, index, source): return self.clone().scatter_add_(dim, index, source) - def masked_copy(self, mask, tensor): - warnings.warn("masked_copy is deprecated and renamed to masked_scatter, and will be removed in v0.3") - return self.masked_scatter(mask, tensor) - - def masked_copy_(self, mask, tensor): - warnings.warn("masked_copy_ is deprecated and renamed to masked_scatter_, and will be removed in v0.3") - return self.masked_scatter_(mask, tensor) - def masked_scatter(self, mask, tensor): return self.clone().masked_scatter_(mask, tensor) From eb338878162c9306f2fa473ad4a365733b5356ec Mon Sep 17 00:00:00 2001 From: David Brownell Date: Thu, 26 Jul 2018 21:00:06 -0700 Subject: [PATCH 11/17] =?UTF-8?q?Addressed=20issue=20identified=20by=20sta?= =?UTF-8?q?tic=20code=20analysis:=20potential=20buffer=20=E2=80=A6=20(#988?= =?UTF-8?q?9)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: …overrun Pull Request resolved: https://github.com/pytorch/pytorch/pull/9889 Differential Revision: D9026278 Pulled By: soumith fbshipit-source-id: ee2ee255f34731ddc581261984c3caf56faa0e12 --- torch/csrc/Exceptions.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/csrc/Exceptions.cpp b/torch/csrc/Exceptions.cpp index 306dd3a70a4228..4de8a71591e406 100644 --- a/torch/csrc/Exceptions.cpp +++ b/torch/csrc/Exceptions.cpp @@ -109,6 +109,10 @@ static std::string formatMessage(const char *format, va_list fmt_args) { static const size_t ERROR_BUF_SIZE = 1024; char error_buf[ERROR_BUF_SIZE]; vsnprintf(error_buf, ERROR_BUF_SIZE, format, fmt_args); + + // Ensure that the string is null terminated + error_buf[sizeof(error_buf) / sizeof(*error_buf) - 1] = 0; + return std::string(error_buf); } From aa671ddefacee4abb471e851d5214f508d64f235 Mon Sep 17 00:00:00 2001 From: James Sun Date: Thu, 26 Jul 2018 21:34:59 -0700 Subject: [PATCH 12/17] Support production models with predictor benchmark (#9855) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9855 Support production models with predictor benchmark Two new flags are added: `--update_prod`: pull production data (netdef, input types, input dims) from Hive and store locally `--use_prod`: run benchmark with local production data with the same workload as in production. By default, 300 models will be loaded. production vs benchmark avg net run time: (collected by prod: https://fburl.com/scuba/6lb91zfx and bench: https://fburl.com/ngjj1dc8) **prod: `408us` vs bench: `543us`** (With prod data distribution, this should be even closer) framework overhead (as of 2018-07-22): prod: ``` 9.111% BlackBoxPredictor::Run 4.602% SimpleNet::Run 2.377% Operator::Run 1.786% BlackBoxPredictor::AllocateMemory 1.372% Observable::StartAllObservers 1.358% Observable::StartObserver 1.206% Blob::GetMutable ``` bench: ``` 8.577% BlackBoxPredictor::operator() 3.276% SimpleNet::Run 1.954% Operator::Run 1.697% BlackBoxPredictor::AllocateMemory 1.477% Tensor::ShareData 1.230% Blob::GetMutable 1.034% Observable::StartObserver ``` Reviewed By: yinghai Differential Revision: D8942996 fbshipit-source-id: 27355d7bb5a9fd8d0a40195261d13a97fa24ce17 --- caffe2/core/operator.h | 12 +++++++----- .../lengths_reducer_fused_8bit_rowwise_ops.h | 16 +++++++++++++++- caffe2/operators/lengths_reducer_ops.h | 16 +++++++++++++++- caffe2/operators/one_hot_ops.cc | 3 +++ caffe2/operators/one_hot_ops.h | 8 ++++++++ caffe2/operators/slice_op.h | 3 +++ 6 files changed, 51 insertions(+), 7 deletions(-) diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index 325ccd3761afb3..757019ef64f07a 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -533,10 +533,11 @@ class Operator : public OperatorBase { return fillers; } -#define DISABLE_INPUT_FILLERS(Context) \ - std::vector> InputFillers( \ - const std::vector>& /* unused */) override { \ - throw UnsupportedOperatorFeature("Op does not have input fillers"); \ +#define DISABLE_INPUT_FILLERS(Context) \ + std::vector> InputFillers( \ + const std::vector>& /* unused */) override { \ + throw UnsupportedOperatorFeature( \ + OperatorBase::type() + " does not have input fillers"); \ } void SparseLengthsFillerHelper( @@ -554,7 +555,8 @@ class Operator : public OperatorBase { size_t segment_index, std::vector>* fillers) { CAFFE_ENFORCE_EQ(shapes[segment_index].size(), 1); - // TODO: what would be a proper #segments + // TODO (mnaumov): distribution of value + (*fillers)[value_index].Min(0).Max(shapes[value_index].front() * 2); (*fillers)[segment_index].SparseSegments(shapes[value_index].front() - 1); } diff --git a/caffe2/operators/lengths_reducer_fused_8bit_rowwise_ops.h b/caffe2/operators/lengths_reducer_fused_8bit_rowwise_ops.h index 7c42d522f2e71f..198e4d81f772a3 100644 --- a/caffe2/operators/lengths_reducer_fused_8bit_rowwise_ops.h +++ b/caffe2/operators/lengths_reducer_fused_8bit_rowwise_ops.h @@ -68,7 +68,21 @@ class SparseLengthsFused8BitRowwiseOp : public Operator { return true; } - USE_VALUE_KEY_LENGTH_INPUT_FILLERS(Context, DATA, INDICES, LENGTHS) + std::vector> InputFillers( + const std::vector>& shapes) override { + CAFFE_ENFORCE_EQ(shapes.size(), Operator::Inputs().size()); + auto fillers = Operator::InputFillers(shapes); + if (with_weights) { + // TODO: enable the fillers + throw UnsupportedOperatorFeature( + OperatorBase::type() + " does not have input fillers"); + } + Operator::SparseLengthsFillerHelper( + shapes, INDICES, LENGTHS, &fillers); + Operator::SparseSegmentsFillerHelper( + shapes, DATA, INDICES, &fillers); + return fillers; + } private: enum { diff --git a/caffe2/operators/lengths_reducer_ops.h b/caffe2/operators/lengths_reducer_ops.h index 505dad1b102de3..f96c379ae81624 100644 --- a/caffe2/operators/lengths_reducer_ops.h +++ b/caffe2/operators/lengths_reducer_ops.h @@ -92,7 +92,21 @@ class CPUSparseLengthsReductionOp : public Operator { return true; } - USE_VALUE_KEY_LENGTH_INPUT_FILLERS(CPUContext, DATA, INDICES, LENGTHS) + std::vector> InputFillers( + const std::vector>& shapes) override { + CAFFE_ENFORCE_EQ(shapes.size(), Operator::Inputs().size()); + auto fillers = Operator::InputFillers(shapes); + if (USE_WEIGHT) { + // TODO: enable the fillers + throw UnsupportedOperatorFeature( + OperatorBase::type() + " does not have input fillers"); + } + Operator::SparseLengthsFillerHelper( + shapes, INDICES, LENGTHS, &fillers); + Operator::SparseSegmentsFillerHelper( + shapes, DATA, INDICES, &fillers); + return fillers; + } private: enum { diff --git a/caffe2/operators/one_hot_ops.cc b/caffe2/operators/one_hot_ops.cc index bb8a1dbc774413..4465e1c744044f 100644 --- a/caffe2/operators/one_hot_ops.cc +++ b/caffe2/operators/one_hot_ops.cc @@ -172,6 +172,9 @@ class SegmentOneHotOp : public Operator { SegmentOneHotOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws) {} + // TODO: enable input filler + DISABLE_INPUT_FILLERS(CPUContext) + bool RunOnDevice() override { auto& lengths = Input(0); auto& indices = Input(1); diff --git a/caffe2/operators/one_hot_ops.h b/caffe2/operators/one_hot_ops.h index 1b48b69326f3e7..644b3e74dd978f 100644 --- a/caffe2/operators/one_hot_ops.h +++ b/caffe2/operators/one_hot_ops.h @@ -13,6 +13,9 @@ class OneHotOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; + // TODO: enable input filler + DISABLE_INPUT_FILLERS(Context) + OneHotOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws) {} @@ -58,6 +61,8 @@ class BatchOneHotOp final : public Operator { BatchOneHotOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws) {} + USE_VALUE_KEY_LENGTH_INPUT_FILLERS(Context, X, VALS, LENS) + bool RunOnDevice() override { return DispatchHelper>::call(this, Input(X)); } @@ -83,6 +88,9 @@ class BatchBucketOneHotOp final : public Operator { bool RunOnDevice() override; + // TODO: enable input filler + DISABLE_INPUT_FILLERS(Context) + protected: INPUT_TAGS(X, LENS, BOUNDARIES); OUTPUT_TAGS(ONE_HOT); diff --git a/caffe2/operators/slice_op.h b/caffe2/operators/slice_op.h index 12734a8e33df71..01eed59598a87c 100644 --- a/caffe2/operators/slice_op.h +++ b/caffe2/operators/slice_op.h @@ -212,6 +212,9 @@ class SliceOp : public Operator { return RunOnDeviceImpl(Input(0), Output(0)); } + // This cannot be enabled given the output dims depends on the input + DISABLE_INPUT_FILLERS(Context) + protected: bool RunOnDeviceImpl(const Tensor& data, Tensor* output) { if (InputSize() > 1) { From 2c1d9e09b89e3094c24eb770f35db09a45b4a877 Mon Sep 17 00:00:00 2001 From: Vignesh Ramanathan Date: Thu, 26 Jul 2018 21:59:42 -0700 Subject: [PATCH 13/17] Support UINT8 for addition data in ImageInputOp (#9901) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9901 Added support for UINT8 datatype for additional data (prefetching and output) by ImageInputOp Reviewed By: ashwinb Differential Revision: D9018964 fbshipit-source-id: f938a8a072c15c0ee521b2f16788c024b08cd37f --- caffe2/image/image_input_op.h | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/caffe2/image/image_input_op.h b/caffe2/image/image_input_op.h index a8c45ca87d46a1..6bf232977d92f2 100644 --- a/caffe2/image/image_input_op.h +++ b/caffe2/image/image_input_op.h @@ -658,8 +658,16 @@ bool ImageInputOp::GetImageAndLabelAndInfoFromDBValue( for (int j = 0; j < additional_output_proto.int64_data_size(); ++j) { additional_output[j] = additional_output_proto.int64_data(j); } - } - else { + } else if (additional_output_proto.data_type() == TensorProto::UINT8) { + uint8_t* additional_output = + prefetched_additional_outputs_[i].template mutable_data() + + item_id * additional_output_proto.int32_data_size(); + + for (int j = 0; j < additional_output_proto.int32_data_size(); ++j) { + additional_output[j] = + static_cast(additional_output_proto.int32_data(j)); + } + } else { LOG(FATAL) << "Unsupported output type."; } } @@ -1148,6 +1156,9 @@ bool ImageInputOp::Prefetch() { } else if ( additional_output_proto.data_type() == TensorProto::INT64) { prefetched_additional_outputs_[i].template mutable_data(); + } else if ( + additional_output_proto.data_type() == TensorProto::UINT8) { + prefetched_additional_outputs_[i].template mutable_data(); } else { LOG(FATAL) << "Unsupported output type."; } From 8cb1eef7b9e512b3f2b4f89a5ca58f171fdde87f Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 26 Jul 2018 22:03:44 -0700 Subject: [PATCH 14/17] Unify IR operator representation (stop using attributes in the JIT) (#9807) Summary: Based on top of #9763 (first 3 commits belong to that PR). The first commits from this PR are "Stop using attributes ..." I tried to separate the changes into fairly meaningful commits. I can't split them up into smaller PRs, because everything starts working and all tests pass only after the whole sequence, but hopefully this will make reviewing somewhat easier. Known issues/regressions/future tasks: - `aten::lerp` and `aten::clamp` are no longer fusable - `CreateAutodiffSubgraphs` needs a rewrite - It is much more strict now, and will miss a lot of opportunities, especially when viewing ops are involved. Our previous approach was "ignore the assumption on shape availability in gradient formulas to determine differentiability, and hope that shape prop will be robust enough to actually deliver them before we differentiate", which obviously doesn't scale well to more complex cases. We should either work on reducing the size dependency of grad formulas (feasible e.g. for `view`/`reshape`, unfeasible for `squeeze`/`unsqueeze`), or make `CreateAutodiffSubgraphs` integrate some kind of "I could integrate this node into an AD subgraph, but will I be able to infer the shape of its input" reasoning (kind of like a limited shape prop, that doesn't infer anything, and only tells if it *could* infer something). - It sometimes creates constant-only (or constants + one node) graphs, which is useless - Broken `aten::add` in auto-batching, because it gained a non-tensor input. I changed the test for pointwise operations to use `aten::mul` instead, but I needed to disable the LSTM cell test. I'm not sure how scalar constants should be implemented in this case, because I don't fully understand our format. cc: ChunliF - Graph import does some hacks to recover type of constants. This code should be removed once we'll gain the ability to export the IR along with value types. - There's still a fair amount of dead code that can be removed. I didn't want to make this diff any bigger, and removing it is an easy task. - Graph fuser could be improved to use signature matching (possibly using `OperatorSet`) instead of basing on node kinds. - Manual constant propagation for the `ListConstruct` node in `torch/onnx/utils.py` should be replaced with a proper constant propagation pass (or we should ensure that the one we have handles at least this case before we remove this code). zdevito Pull Request resolved: https://github.com/pytorch/pytorch/pull/9807 Reviewed By: ezyang Differential Revision: D9004285 Pulled By: apaszke fbshipit-source-id: fe88026a765f6b687354add034c86402362508b7 --- test/expect/TestJit.test_alexnet.expect | 81 ++-- test/expect/TestJit.test_batchnorm.expect | 8 +- test/expect/TestJit.test_concat_fusion.expect | 14 +- test/expect/TestJit.test_conv.expect | 21 +- test/expect/TestJit.test_cpp.expect | 166 ++++---- test/expect/TestJit.test_cse.expect | 15 +- .../TestJit.test_decompose_addmm.expect | 31 +- .../TestJit.test_fuse_last_device.expect | 14 +- .../TestJit.test_fusion_distribute.expect | 22 +- .../TestJit.test_inplace_transplant.expect | 10 +- .../TestJit.test_lstm_fusion_concat.expect | 60 +-- .../TestJit.test_lstm_fusion_cuda.expect | 57 +-- .../expect/TestJit.test_nested_inplace.expect | 6 +- test/expect/TestJit.test_python_ir.expect | 13 +- .../expect/TestJit.test_repeated_input.expect | 5 +- .../TestJit.test_repeated_output.expect | 5 +- test/expect/TestJit.test_scopes.expect | 11 +- ...stJit.test_shape_analysis_broadcast.expect | 13 +- test/expect/TestJit.test_shared_param.expect | 5 +- test/expect/TestJit.test_simple.expect | 11 +- test/expect/TestJit.test_trace_size.expect | 21 +- .../TestJit.test_trace_size_with_grad.expect | 21 +- ....test_call_python_fn_from_script_fn.expect | 4 +- ...test_call_python_fn_from_tracing_fn.expect | 6 +- ...test_call_python_mod_from_script_fn.expect | 4 +- ..._call_python_mod_from_traced_module.expect | 6 +- ...est_call_python_mod_from_tracing_fn.expect | 6 +- ....test_call_script_fn_from_script_fn.expect | 4 +- ...test_call_script_fn_from_tracing_fn.expect | 6 +- ...test_call_script_mod_from_script_fn.expect | 4 +- ...est_call_script_mod_from_tracing_fn.expect | 6 +- ...ll_script_module_from_traced_module.expect | 6 +- ....test_call_traced_fn_from_script_fn.expect | 4 +- ...test_call_traced_fn_from_tracing_fn.expect | 6 +- ...test_call_traced_mod_from_script_fn.expect | 4 +- ...est_call_traced_mod_from_tracing_fn.expect | 6 +- ...ll_traced_module_from_traced_module.expect | 6 +- test/expect/TestScript.test_cat_lifts.expect | 15 +- ...ript.test_index_put_trace_with_view.expect | 11 +- ...t.test_index_put_trace_without_view.expect | 7 +- ...Script.test_index_select_shape_prop.expect | 5 +- ...ipt.test_loop_unroll_unused_counter.expect | 46 ++- .../TestScript.test_loop_unrolling.expect | 18 +- ...test_loop_unrolling_const-add_const.expect | 40 +- ....test_loop_unrolling_const-add_iter.expect | 20 +- ...stScript.test_loop_unrolling_nested.expect | 18 +- .../expect/TestScript.test_math_schema.expect | 5 +- .../TestScript.test_math_tensor_number.expect | 6 +- ...stOperators.test_batchnorm_training.expect | 4 +- test/test_jit.py | 14 +- tools/autograd/gen_variable_type.py | 57 +-- tools/autograd/templates/VariableType.cpp | 61 --- tools/jit/gen_jit_dispatch.py | 2 +- torch/csrc/jit/autodiff.cpp | 383 +++++++++++------- torch/csrc/jit/fusion_compiler.cpp | 62 ++- torch/csrc/jit/graph_executor.cpp | 4 +- torch/csrc/jit/import.cpp | 40 ++ torch/csrc/jit/interpreter.cpp | 1 + torch/csrc/jit/ir.cpp | 18 +- torch/csrc/jit/operator.cpp | 30 +- torch/csrc/jit/operator.h | 10 + .../passes/common_subexpression_elimination.h | 1 + torch/csrc/jit/passes/graph_fuser.cpp | 162 ++++---- torch/csrc/jit/passes/onnx.cpp | 2 + torch/csrc/jit/passes/onnx/peephole.cpp | 23 +- torch/csrc/jit/passes/peephole.cpp | 2 +- torch/csrc/jit/passes/remove_expands.cpp | 2 +- torch/csrc/jit/passes/shape_analysis.cpp | 37 +- torch/csrc/jit/passes/shape_analysis.h | 3 + torch/csrc/jit/python_ir.cpp | 7 +- torch/csrc/jit/python_tracer.cpp | 25 +- torch/csrc/jit/script/compiler.cpp | 79 +--- torch/csrc/jit/symbolic_variable.h | 159 +++----- torch/csrc/jit/tracer.cpp | 52 ++- torch/csrc/jit/tracer.h | 48 ++- torch/onnx/symbolic.py | 188 ++++++--- torch/onnx/utils.py | 33 +- 77 files changed, 1369 insertions(+), 1019 deletions(-) diff --git a/test/expect/TestJit.test_alexnet.expect b/test/expect/TestJit.test_alexnet.expect index 9a1105c8b8b176..09d274f58be622 100644 --- a/test/expect/TestJit.test_alexnet.expect +++ b/test/expect/TestJit.test_alexnet.expect @@ -15,38 +15,51 @@ graph(%0 : Double(1, 3, 224, 224) %14 : Double(4096) %15 : Double(1000, 4096) %16 : Double(1000)) { - %17 : Double(1, 64, 55, 55) = aten::_convolution[stride=[4, 4], padding=[2, 2], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%0, %1, %2), scope: AlexNet/Sequential[features]/Conv2d[0] - %18 : Double(1, 64, 55, 55) = aten::threshold[threshold={0}, value={0}](%17), scope: AlexNet/Sequential[features]/ReLU[1] - %19 : Double(1, 64, 27, 27), %20 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2] - %21 : Double(1, 192, 27, 27) = aten::_convolution[stride=[1, 1], padding=[2, 2], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%19, %3, %4), scope: AlexNet/Sequential[features]/Conv2d[3] - %22 : Double(1, 192, 27, 27) = aten::threshold[threshold={0}, value={0}](%21), scope: AlexNet/Sequential[features]/ReLU[4] - %23 : Double(1, 192, 13, 13), %24 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%22), scope: AlexNet/Sequential[features]/MaxPool2d[5] - %25 : Double(1, 384, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%23, %5, %6), scope: AlexNet/Sequential[features]/Conv2d[6] - %26 : Double(1, 384, 13, 13) = aten::threshold[threshold={0}, value={0}](%25), scope: AlexNet/Sequential[features]/ReLU[7] - %27 : Double(1, 256, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%26, %7, %8), scope: AlexNet/Sequential[features]/Conv2d[8] - %28 : Double(1, 256, 13, 13) = aten::threshold[threshold={0}, value={0}](%27), scope: AlexNet/Sequential[features]/ReLU[9] - %29 : Double(1, 256, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%28, %9, %10), scope: AlexNet/Sequential[features]/Conv2d[10] - %30 : Double(1, 256, 13, 13) = aten::threshold[threshold={0}, value={0}](%29), scope: AlexNet/Sequential[features]/ReLU[11] - %31 : Double(1, 256, 6, 6), %32 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%30), scope: AlexNet/Sequential[features]/MaxPool2d[12] - %33 : int = prim::Constant[value=0](), scope: AlexNet - %34 : int = aten::size(%31, %33), scope: AlexNet - %35 : Long() = prim::NumToTensor(%34), scope: AlexNet - %36 : int = prim::TensorToNum(%35), scope: AlexNet - %37 : int = prim::Constant[value=9216](), scope: AlexNet - %38 : int[] = prim::ListConstruct(%36, %37), scope: AlexNet - %39 : Double(1, 9216) = aten::view(%31, %38), scope: AlexNet - %40 : Double(1, 9216) = ^Dropout(0.5, True, False)(%39), scope: AlexNet/Sequential[classifier]/Dropout[0] - %41 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1] - %42 : Double(1, 4096) = aten::expand[size=[1, 4096], implicit=1](%12), scope: AlexNet/Sequential[classifier]/Linear[1] - %43 : Double(1, 4096) = aten::addmm[beta={1}, alpha={1}](%42, %40, %41), scope: AlexNet/Sequential[classifier]/Linear[1] - %44 : Double(1, 4096) = aten::threshold[threshold={0}, value={0}](%43), scope: AlexNet/Sequential[classifier]/ReLU[2] - %45 : Double(1, 4096) = ^Dropout(0.5, True, False)(%44), scope: AlexNet/Sequential[classifier]/Dropout[3] - %46 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4] - %47 : Double(1, 4096) = aten::expand[size=[1, 4096], implicit=1](%14), scope: AlexNet/Sequential[classifier]/Linear[4] - %48 : Double(1, 4096) = aten::addmm[beta={1}, alpha={1}](%47, %45, %46), scope: AlexNet/Sequential[classifier]/Linear[4] - %49 : Double(1, 4096) = aten::threshold[threshold={0}, value={0}](%48), scope: AlexNet/Sequential[classifier]/ReLU[5] - %50 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6] - %51 : Double(1, 1000) = aten::expand[size=[1, 1000], implicit=1](%16), scope: AlexNet/Sequential[classifier]/Linear[6] - %52 : Double(1, 1000) = aten::addmm[beta={1}, alpha={1}](%51, %49, %50), scope: AlexNet/Sequential[classifier]/Linear[6] - return (%52); + %17 : int = prim::Constant[value=4](), scope: AlexNet/Sequential[features]/Conv2d[0] + %18 : int[] = prim::ListConstruct(%17, %17), scope: AlexNet/Sequential[features]/Conv2d[0] + %19 : int = prim::Constant[value=2](), scope: AlexNet/Sequential[features]/Conv2d[0] + %20 : int[] = prim::ListConstruct(%19, %19), scope: AlexNet/Sequential[features]/Conv2d[0] + %21 : int = prim::Constant[value=1](), scope: AlexNet/Sequential[features]/Conv2d[0] + %22 : int[] = prim::ListConstruct(%21, %21), scope: AlexNet/Sequential[features]/Conv2d[0] + %23 : int = prim::Constant[value=0](), scope: AlexNet/Sequential[features]/Conv2d[0] + %24 : int[] = prim::ListConstruct(%23, %23), scope: AlexNet/Sequential[features]/Conv2d[0] + %25 : Double(1, 64, 55, 55) = aten::_convolution(%0, %1, %2, %18, %20, %22, %23, %24, %21, %23, %23, %21), scope: AlexNet/Sequential[features]/Conv2d[0] + %26 : Double(1, 64, 55, 55) = aten::threshold(%25, %23, %23), scope: AlexNet/Sequential[features]/ReLU[1] + %27 : int = prim::Constant[value=3](), scope: AlexNet/Sequential[features]/MaxPool2d[2] + %28 : int[] = prim::ListConstruct(%27, %27), scope: AlexNet/Sequential[features]/MaxPool2d[2] + %29 : Double(1, 64, 27, 27), %30 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices(%26, %28, %20, %24, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[2] + %31 : Double(1, 192, 27, 27) = aten::_convolution(%29, %3, %4, %22, %20, %22, %23, %24, %21, %23, %23, %21), scope: AlexNet/Sequential[features]/Conv2d[3] + %32 : Double(1, 192, 27, 27) = aten::threshold(%31, %23, %23), scope: AlexNet/Sequential[features]/ReLU[4] + %33 : Double(1, 192, 13, 13), %34 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices(%32, %28, %20, %24, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[5] + %35 : Double(1, 384, 13, 13) = aten::_convolution(%33, %5, %6, %22, %22, %22, %23, %24, %21, %23, %23, %21), scope: AlexNet/Sequential[features]/Conv2d[6] + %36 : Double(1, 384, 13, 13) = aten::threshold(%35, %23, %23), scope: AlexNet/Sequential[features]/ReLU[7] + %37 : Double(1, 256, 13, 13) = aten::_convolution(%36, %7, %8, %22, %22, %22, %23, %24, %21, %23, %23, %21), scope: AlexNet/Sequential[features]/Conv2d[8] + %38 : Double(1, 256, 13, 13) = aten::threshold(%37, %23, %23), scope: AlexNet/Sequential[features]/ReLU[9] + %39 : Double(1, 256, 13, 13) = aten::_convolution(%38, %9, %10, %22, %22, %22, %23, %24, %21, %23, %23, %21), scope: AlexNet/Sequential[features]/Conv2d[10] + %40 : Double(1, 256, 13, 13) = aten::threshold(%39, %23, %23), scope: AlexNet/Sequential[features]/ReLU[11] + %41 : Double(1, 256, 6, 6), %42 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices(%40, %28, %20, %24, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[12] + %43 : int = aten::size(%41, %23), scope: AlexNet + %44 : Long() = prim::NumToTensor(%43), scope: AlexNet + %45 : int = prim::TensorToNum(%44), scope: AlexNet + %46 : int = prim::Constant[value=9216](), scope: AlexNet + %47 : int[] = prim::ListConstruct(%45, %46), scope: AlexNet + %48 : Double(1, 9216) = aten::view(%41, %47), scope: AlexNet + %49 : Double(1, 9216) = ^Dropout(0.5, True, False)(%48), scope: AlexNet/Sequential[classifier]/Dropout[0] + %50 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1] + %51 : int = prim::Constant[value=4096](), scope: AlexNet/Sequential[classifier]/Linear[1] + %52 : int[] = prim::ListConstruct(%21, %51), scope: AlexNet/Sequential[classifier]/Linear[1] + %53 : Double(1, 4096) = aten::expand(%12, %52, %21), scope: AlexNet/Sequential[classifier]/Linear[1] + %54 : Double(1, 4096) = aten::addmm(%53, %49, %50, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1] + %55 : Double(1, 4096) = aten::threshold(%54, %23, %23), scope: AlexNet/Sequential[classifier]/ReLU[2] + %56 : Double(1, 4096) = ^Dropout(0.5, True, False)(%55), scope: AlexNet/Sequential[classifier]/Dropout[3] + %57 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4] + %58 : Double(1, 4096) = aten::expand(%14, %52, %21), scope: AlexNet/Sequential[classifier]/Linear[4] + %59 : Double(1, 4096) = aten::addmm(%58, %56, %57, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4] + %60 : Double(1, 4096) = aten::threshold(%59, %23, %23), scope: AlexNet/Sequential[classifier]/ReLU[5] + %61 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6] + %62 : int = prim::Constant[value=1000](), scope: AlexNet/Sequential[classifier]/Linear[6] + %63 : int[] = prim::ListConstruct(%21, %62), scope: AlexNet/Sequential[classifier]/Linear[6] + %64 : Double(1, 1000) = aten::expand(%16, %63, %21), scope: AlexNet/Sequential[classifier]/Linear[6] + %65 : Double(1, 1000) = aten::addmm(%64, %60, %61, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6] + return (%65); } diff --git a/test/expect/TestJit.test_batchnorm.expect b/test/expect/TestJit.test_batchnorm.expect index 4fa8a72a43ae7f..c61390578d45b8 100644 --- a/test/expect/TestJit.test_batchnorm.expect +++ b/test/expect/TestJit.test_batchnorm.expect @@ -4,6 +4,10 @@ graph(%0 : Double(2, 2, 2, 2) %3 : Double(2) %4 : Double(2) %5 : Long()) { - %6 : Double(2, 2, 2, 2) = aten::batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d - return (%6); + %6 : int = prim::Constant[value=1](), scope: BatchNorm2d + %7 : float = prim::Constant[value=0.1](), scope: BatchNorm2d + %8 : float = prim::Constant[value=1e-05](), scope: BatchNorm2d + %9 : int = prim::Constant[value=1](), scope: BatchNorm2d + %10 : Double(2, 2, 2, 2) = aten::batch_norm(%0, %1, %2, %3, %4, %6, %7, %8, %9), scope: BatchNorm2d + return (%10); } diff --git a/test/expect/TestJit.test_concat_fusion.expect b/test/expect/TestJit.test_concat_fusion.expect index c1b45b172745ba..027c2de33e5926 100644 --- a/test/expect/TestJit.test_concat_fusion.expect +++ b/test/expect/TestJit.test_concat_fusion.expect @@ -3,10 +3,12 @@ graph(%0 : Float(3, 20) %2 : Float(6, 20) = prim::FusionGroup_0[device=0](%0, %1) return (%2); } -with prim::FusionGroup_0 = graph(%3 : Float(3, 20) - %4 : Float(3, 20)) { - %6 : Float(3, 20) = aten::add[alpha={1}](%3, %4) - %5 : Float(3, 20) = aten::mul(%3, %4) - %2 : Float(6, 20) = aten::cat[dim=0](%6, %5) - return (%2); +with prim::FusionGroup_0 = graph(%4 : Float(3, 20) + %5 : Float(3, 20)) { + %7 : int = prim::Constant[value=1]() + %8 : Float(3, 20) = aten::add(%4, %5, %7) + %6 : Float(3, 20) = aten::mul(%4, %5) + %2 : int = prim::Constant[value=0]() + %3 : Float(6, 20) = aten::cat(%8, %6, %2) + return (%3); } diff --git a/test/expect/TestJit.test_conv.expect b/test/expect/TestJit.test_conv.expect index 584f807a8ca071..fcb53bad1425dc 100644 --- a/test/expect/TestJit.test_conv.expect +++ b/test/expect/TestJit.test_conv.expect @@ -1,6 +1,23 @@ graph(%0 : Double(20, 16, 50, 40) %1 : Double(13, 16, 3, 3)) { %2 : Dynamic = prim::Undefined(), scope: Conv2d - %3 : Double(20, 13, 48, 38) = aten::_convolution[stride=[1, 1], padding=[0, 0], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%0, %1, %2), scope: Conv2d - return (%3); + %3 : int = prim::Constant[value=1](), scope: Conv2d + %4 : int = prim::Constant[value=1](), scope: Conv2d + %5 : int[] = prim::ListConstruct(%3, %4), scope: Conv2d + %6 : int = prim::Constant[value=0](), scope: Conv2d + %7 : int = prim::Constant[value=0](), scope: Conv2d + %8 : int[] = prim::ListConstruct(%6, %7), scope: Conv2d + %9 : int = prim::Constant[value=1](), scope: Conv2d + %10 : int = prim::Constant[value=1](), scope: Conv2d + %11 : int[] = prim::ListConstruct(%9, %10), scope: Conv2d + %12 : int = prim::Constant[value=0](), scope: Conv2d + %13 : int = prim::Constant[value=0](), scope: Conv2d + %14 : int = prim::Constant[value=0](), scope: Conv2d + %15 : int[] = prim::ListConstruct(%13, %14), scope: Conv2d + %16 : int = prim::Constant[value=1](), scope: Conv2d + %17 : int = prim::Constant[value=0](), scope: Conv2d + %18 : int = prim::Constant[value=0](), scope: Conv2d + %19 : int = prim::Constant[value=1](), scope: Conv2d + %20 : Double(20, 13, 48, 38) = aten::_convolution(%0, %1, %2, %5, %8, %11, %12, %15, %16, %17, %18, %19), scope: Conv2d + return (%20); } diff --git a/test/expect/TestJit.test_cpp.expect b/test/expect/TestJit.test_cpp.expect index bfe49e45cfb618..f1f3a6a9c39012 100644 --- a/test/expect/TestJit.test_cpp.expect +++ b/test/expect/TestJit.test_cpp.expect @@ -2,47 +2,60 @@ testBlocks graph(%a : Dynamic %b : Dynamic %c : Dynamic) { - %2 : Dynamic = aten::add[alpha={1}](%a, %b) - %4 : Dynamic = prim::If(%c) + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::add(%a, %b, %2) + %5 : Dynamic = prim::If(%c) block0() { - %5 : Dynamic = aten::add[alpha={1}](%2, %2) - -> (%5) + %6 : int = prim::Constant[value=1]() + %7 : Dynamic = aten::add(%3, %3, %6) + -> (%7) } block1() { - %6 : Dynamic = aten::add[alpha={1}](%b, %2) - %7 : Dynamic = aten::add[alpha={1}](%6, %2) - -> (%7) + %8 : int = prim::Constant[value=1]() + %9 : Dynamic = aten::add(%b, %3, %8) + %10 : int = prim::Constant[value=1]() + %11 : Dynamic = aten::add(%9, %3, %10) + -> (%11) } - %8 : Dynamic = aten::add[alpha={1}](%4, %2) - return (%8); + %12 : int = prim::Constant[value=1]() + %13 : Dynamic = aten::add(%5, %3, %12) + return (%13); } graph(%a : Dynamic %b : Dynamic %c : Dynamic) { - %2 : Dynamic = aten::add[alpha={1}](%a, %b) - %4 : Dynamic = prim::If(%c) + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::add(%a, %b, %2) + %5 : Dynamic = prim::If(%c) block0() { - %6 : Dynamic = aten::add[alpha={1}](%b, %2) - %7 : Dynamic = aten::add[alpha={1}](%6, %2) - -> (%7) + %8 : int = prim::Constant[value=1]() + %9 : Dynamic = aten::add(%b, %3, %8) + %10 : int = prim::Constant[value=1]() + %11 : Dynamic = aten::add(%9, %3, %10) + -> (%11) } - %8 : Dynamic = aten::add[alpha={1}](%4, %2) - return (%8); + %12 : int = prim::Constant[value=1]() + %13 : Dynamic = aten::add(%5, %3, %12) + return (%13); } graph(%a : Dynamic %b : Dynamic %c : Dynamic) { - %3 : Dynamic = aten::add[alpha={1}](%a, %b) - %4 : Dynamic = prim::If(%c) + %3 : int = prim::Constant[value=1]() + %4 : Dynamic = aten::add(%a, %b, %3) + %5 : Dynamic = prim::If(%c) block0() { - %5 : Dynamic = aten::add[alpha={1}](%b, %3) - %6 : Dynamic = aten::add[alpha={1}](%5, %3) - -> (%6) + %6 : int = prim::Constant[value=1]() + %7 : Dynamic = aten::add(%b, %4, %6) + %8 : int = prim::Constant[value=1]() + %9 : Dynamic = aten::add(%7, %4, %8) + -> (%9) } - %7 : Dynamic = aten::add[alpha={1}](%4, %3) - return (%7); + %10 : int = prim::Constant[value=1]() + %11 : Dynamic = aten::add(%5, %4, %10) + return (%11); } testCreateAutodiffSubgraphs @@ -51,28 +64,32 @@ graph(%0 : Dynamic %2 : Dynamic %3 : Dynamic %4 : Dynamic) { - %21 : Dynamic, %22 : Dynamic = prim::GraphExecutor_0(%0, %3, %1, %4, %2) - return (%22, %21); + %25 : Dynamic, %26 : Dynamic = prim::GraphExecutor_0(%0, %3, %1, %4, %2) + return (%26, %25); } with prim::GraphExecutor_0 = graph(%1 : Dynamic %2 : Dynamic %4 : Dynamic %5 : Dynamic - %16 : Dynamic) { + %19 : Dynamic) { %0 : Dynamic = aten::mm(%1, %2) %3 : Dynamic = aten::mm(%4, %5) - %6 : Dynamic = aten::add[alpha={1}](%0, %3) - %7 : Dynamic, %8 : Dynamic, %9 : Dynamic, %10 : Dynamic = aten::chunk[chunks=4, dim=1](%6) - %11 : Dynamic = aten::sigmoid(%7) - %12 : Dynamic = aten::sigmoid(%10) - %13 : Dynamic = aten::tanh(%9) - %14 : Dynamic = aten::sigmoid(%8) - %15 : Dynamic = aten::mul(%14, %16) - %17 : Dynamic = aten::mul(%11, %13) - %18 : Dynamic = aten::add[alpha={1}](%15, %17) - %19 : Dynamic = aten::tanh(%18) - %20 : Dynamic = aten::mul(%12, %19) - return (%18, %20); + %6 : int = prim::Constant[value=1]() + %7 : Dynamic = aten::add(%0, %3, %6) + %8 : int = prim::Constant[value=4]() + %9 : int = prim::Constant[value=1]() + %10 : Dynamic, %11 : Dynamic, %12 : Dynamic, %13 : Dynamic = aten::chunk(%7, %8, %9) + %14 : Dynamic = aten::sigmoid(%10) + %15 : Dynamic = aten::sigmoid(%13) + %16 : Dynamic = aten::tanh(%12) + %17 : Dynamic = aten::sigmoid(%11) + %18 : Dynamic = aten::mul(%17, %19) + %20 : Dynamic = aten::mul(%14, %16) + %21 : int = prim::Constant[value=1]() + %22 : Dynamic = aten::add(%18, %20, %21) + %23 : Dynamic = aten::tanh(%22) + %24 : Dynamic = aten::mul(%15, %23) + return (%22, %24); } testDifferentiate @@ -80,66 +97,75 @@ graph(%0 : Float(2, 3, 4) %1 : Float(2, 3, 4)) { %2 : Float(2, 3, 4) = aten::mul(%0, %1) %3 : Float(2, 3, 4) = aten::mul(%2, %0) - %4 : Float(2, 3, 4) = aten::add[alpha={1}](%3, %1) - return (%4, %2); + %4 : int = prim::Constant[value=1]() + %5 : Float(2, 3, 4) = aten::add(%3, %1, %4) + return (%5, %2); } graph(%0 : Float(2, 3, 4) %1 : Float(2, 3, 4) %2 : Float(2, 3, 4) %3 : Float(2, 3, 4) %4 : Float(2, 3, 4)) { - %5 : Float(2, 3, 4), %6 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%0) + %5 : int = prim::Constant[value=1]() + %6 : Float(2, 3, 4), %7 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%0) block0() { - -> (%0, %0) + %8 : Float(2, 3, 4) = aten::mul(%0, %5) + -> (%0, %8) } - %7 : Float(2, 3, 4), %8 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%5) + %9 : Float(2, 3, 4), %10 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%6) block0() { - %9 : Float(2, 3, 4) = aten::mul(%5, %2) - %10 : Float(2, 3, 4) = aten::mul(%5, %4) - -> (%9, %10) + %11 : Float(2, 3, 4) = aten::mul(%6, %2) + %12 : Float(2, 3, 4) = aten::mul(%6, %4) + -> (%11, %12) } - %11 : Dynamic = prim::AutogradAdd(%1, %7) - %12 : Float(2, 3, 4), %13 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%11) + %13 : Dynamic = prim::AutogradAdd(%1, %9) + %14 : Float(2, 3, 4), %15 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%13) block0() { - %14 : Float(2, 3, 4) = aten::mul(%11, %3) - %15 : Float(2, 3, 4) = aten::mul(%11, %2) - -> (%14, %15) + %16 : Float(2, 3, 4) = aten::mul(%13, %3) + %17 : Float(2, 3, 4) = aten::mul(%13, %2) + -> (%16, %17) } - %16 : Dynamic = prim::AutogradAdd(%8, %12) - %17 : Dynamic = prim::AutogradAdd(%6, %13) - return (%16, %17); + %18 : Dynamic = prim::AutogradAdd(%10, %14) + %19 : Dynamic = prim::AutogradAdd(%7, %15) + return (%18, %19); } testDifferentiateWithRequiresGrad graph(%0 : Float(2, 3, 4) %1 : Float(2, 3, 4)) { %2 : Float(2, 3, 4) = aten::mul(%1, %1) - %3 : Float(2, 3, 4) = aten::add[alpha={1}](%2, %1) - %4 : Float(2, 3, 4) = aten::add[alpha={1}](%3, %0) - %5 : Float(2, 3, 4) = aten::mul(%4, %0) - %6 : Float(2, 3, 4) = aten::add[alpha={1}](%5, %1) - return (%3, %6, %4); + %3 : int = prim::Constant[value=1]() + %4 : Float(2, 3, 4) = aten::add(%2, %1, %3) + %5 : int = prim::Constant[value=1]() + %6 : Float(2, 3, 4) = aten::add(%4, %0, %5) + %7 : Float(2, 3, 4) = aten::mul(%6, %0) + %8 : int = prim::Constant[value=1]() + %9 : Float(2, 3, 4) = aten::add(%7, %1, %8) + return (%4, %9, %6); } graph(%0 : Float(2, 3, 4) %1 : Float(2, 3, 4) %2 : Float(2, 3, 4) %3 : Float(2, 3, 4)) { - %4 : Float(2, 3, 4), %5 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%0) + %4 : int = prim::Constant[value=1]() + %5 : Float(2, 3, 4), %6 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%0) block0() { - -> (%0, %0) + %7 : Float(2, 3, 4) = aten::mul(%0, %4) + -> (%0, %7) } - %6 : Float(2, 3, 4), %7 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%4) + %8 : Float(2, 3, 4), %9 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%5) block0() { - %8 : Float(2, 3, 4) = aten::mul(%4, %2) - %9 : Float(2, 3, 4) = aten::mul(%4, %3) - -> (%8, %9) + %10 : Float(2, 3, 4) = aten::mul(%5, %2) + %11 : Float(2, 3, 4) = aten::mul(%5, %3) + -> (%10, %11) } - %10 : Dynamic = prim::AutogradAdd(%1, %6) - %11 : Float(2, 3, 4), %12 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%10) + %12 : Dynamic = prim::AutogradAdd(%1, %8) + %13 : Float(2, 3, 4), %14 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%12) block0() { - -> (%10, %10) + %15 : Float(2, 3, 4) = aten::mul(%12, %4) + -> (%12, %15) } - %13 : Dynamic = prim::AutogradAdd(%7, %12) - return (%13); + %16 : Dynamic = prim::AutogradAdd(%9, %14) + return (%16); } diff --git a/test/expect/TestJit.test_cse.expect b/test/expect/TestJit.test_cse.expect index b3d1a81a9929b8..46d9a4c6a17e0c 100644 --- a/test/expect/TestJit.test_cse.expect +++ b/test/expect/TestJit.test_cse.expect @@ -1,10 +1,11 @@ graph(%0 : Double(2) %1 : Double(2)) { - %2 : Double(2) = aten::add[alpha={1}](%0, %1) - %3 : Double(2) = aten::mul(%2, %2) - %4 : Double(2) = aten::mul(%3, %2) - %5 : Double(2) = aten::tanh(%4) - %6 : Double(2) = aten::add[alpha={1}](%5, %5) - %7 : Double(2) = aten::add[alpha={1}](%4, %6) - return (%7); + %2 : int = prim::Constant[value=1]() + %3 : Double(2) = aten::add(%0, %1, %2) + %4 : Double(2) = aten::mul(%3, %3) + %5 : Double(2) = aten::mul(%4, %3) + %6 : Double(2) = aten::tanh(%5) + %7 : Double(2) = aten::add(%6, %6, %2) + %8 : Double(2) = aten::add(%5, %7, %2) + return (%8); } diff --git a/test/expect/TestJit.test_decompose_addmm.expect b/test/expect/TestJit.test_decompose_addmm.expect index 925362f4f6a4ae..65a3e416d2b1e9 100644 --- a/test/expect/TestJit.test_decompose_addmm.expect +++ b/test/expect/TestJit.test_decompose_addmm.expect @@ -3,16 +3,23 @@ graph(%mat : Dynamic %mat2 : Dynamic %alpha : Dynamic %beta : Dynamic) { - %5 : Dynamic = aten::mm(%mat1, %mat2) - %6 : Dynamic = aten::add[alpha={1}](%mat, %5) - %7 : Dynamic = aten::mm(%mat1, %mat2) - %8 : Dynamic = aten::add[alpha={1}](%mat, %7) - %c : Dynamic = aten::addmm[beta={2}, alpha={4.2}](%mat, %mat1, %mat2) - %10 : int = prim::TensorToNum(%alpha) - %11 : int = prim::TensorToNum(%beta) - %d : Dynamic = aten::addmm(%mat, %mat1, %mat2, %11, %10) - %13 : Dynamic = aten::add[alpha={1}](%6, %8) - %14 : Dynamic = aten::add[alpha={1}](%13, %c) - %15 : Dynamic = aten::add[alpha={1}](%14, %d) - return (%15); + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() + %a : Dynamic = aten::addmm(%mat, %mat1, %mat2, %5, %6) + %8 : float = prim::Constant[value=1]() + %9 : float = prim::Constant[value=1]() + %b : Dynamic = aten::addmm(%mat, %mat1, %mat2, %9, %8) + %11 : float = prim::Constant[value=4.2]() + %12 : float = prim::Constant[value=2]() + %c : Dynamic = aten::addmm(%mat, %mat1, %mat2, %12, %11) + %14 : int = prim::TensorToNum(%alpha) + %15 : int = prim::TensorToNum(%beta) + %d : Dynamic = aten::addmm(%mat, %mat1, %mat2, %15, %14) + %17 : int = prim::Constant[value=1]() + %18 : Dynamic = aten::add(%a, %b, %17) + %19 : int = prim::Constant[value=1]() + %20 : Dynamic = aten::add(%18, %c, %19) + %21 : int = prim::Constant[value=1]() + %22 : Dynamic = aten::add(%20, %d, %21) + return (%22); } diff --git a/test/expect/TestJit.test_fuse_last_device.expect b/test/expect/TestJit.test_fuse_last_device.expect index 276fadc61fd7df..e5613bfa975ffd 100644 --- a/test/expect/TestJit.test_fuse_last_device.expect +++ b/test/expect/TestJit.test_fuse_last_device.expect @@ -3,12 +3,14 @@ graph(%0 : Float(1) %2 : Float(1) = prim::FusionGroup_0[device=1](%0, %1) return (%2); } -with prim::FusionGroup_0 = graph(%6 : Float(1) - %9 : Float(1)) { - %10 : Float(1) = aten::add[alpha={1}](%6, %9) - %8 : Float(1) = aten::mul(%6, %10) - %5 : Float(1) = aten::add[other={1}, alpha={1}](%8) - %3 : Float(1) = aten::tanh(%5) +with prim::FusionGroup_0 = graph(%7 : Float(1) + %10 : Float(1)) { + %11 : int = prim::Constant[value=1]() + %12 : Float(1) = aten::add(%7, %10, %11) + %9 : Float(1) = aten::mul(%7, %12) + %5 : int = prim::Constant[value=1]() + %6 : Float(1) = aten::add(%9, %5, %5) + %3 : Float(1) = aten::tanh(%6) %1 : Float(1) = aten::sigmoid(%3) return (%1); } diff --git a/test/expect/TestJit.test_fusion_distribute.expect b/test/expect/TestJit.test_fusion_distribute.expect index 4465074e556585..380a92c8a112d0 100644 --- a/test/expect/TestJit.test_fusion_distribute.expect +++ b/test/expect/TestJit.test_fusion_distribute.expect @@ -1,16 +1,20 @@ graph(%0 : Float(4, 4) %1 : Float(4, 4)) { - %2 : Float(4!, 2), %3 : Float(4!, 2) = aten::chunk[chunks=2, dim=1](%0) - %4 : Float(4!, 2), %5 : Float(4!, 2) = aten::chunk[chunks=2, dim=1](%1) - %6 : Float(4, 2) = prim::FusionGroup_0[device=0](%2, %4, %3, %5) - return (%6); + %2 : int = prim::Constant[value=1]() + %3 : int = prim::Constant[value=2]() + %4 : Float(4!, 2), %5 : Float(4!, 2) = aten::chunk(%0, %3, %2) + %6 : Float(4!, 2), %7 : Float(4!, 2) = aten::chunk(%1, %3, %2) + %8 : Float(4, 2) = prim::FusionGroup_0[device=0](%4, %6, %5, %7) + return (%8); } with prim::FusionGroup_0 = graph(%3 : Float(4!, 2) %4 : Float(4!, 2) - %6 : Float(4!, 2) - %7 : Float(4!, 2)) { - %8 : Float(4, 2) = aten::add[alpha={1}](%6, %7) - %5 : Float(4, 2) = aten::add[alpha={1}](%3, %4) - %2 : Float(4, 2) = aten::mul(%5, %8) + %7 : Float(4!, 2) + %8 : Float(4!, 2)) { + %9 : int = prim::Constant[value=1]() + %10 : Float(4, 2) = aten::add(%7, %8, %9) + %5 : int = prim::Constant[value=1]() + %6 : Float(4, 2) = aten::add(%3, %4, %5) + %2 : Float(4, 2) = aten::mul(%6, %10) return (%2); } diff --git a/test/expect/TestJit.test_inplace_transplant.expect b/test/expect/TestJit.test_inplace_transplant.expect index e31e8c783b62b1..c9a84219a5ed6d 100644 --- a/test/expect/TestJit.test_inplace_transplant.expect +++ b/test/expect/TestJit.test_inplace_transplant.expect @@ -1,6 +1,10 @@ graph(%0 : Double(1)) { %1 : Double(1) = aten::clone(%0) - %2 : Double(1) = aten::add[other={2}, alpha={1}](%1) - %3 : Double(1) = aten::add[other={3}, alpha={1}](%2) - return (%3); + %2 : int = prim::Constant[value=2]() + %3 : int = prim::Constant[value=1]() + %4 : Double(1) = aten::add(%1, %2, %3) + %5 : int = prim::Constant[value=3]() + %6 : int = prim::Constant[value=1]() + %7 : Double(1) = aten::add(%4, %5, %6) + return (%7); } diff --git a/test/expect/TestJit.test_lstm_fusion_concat.expect b/test/expect/TestJit.test_lstm_fusion_concat.expect index 7f6b3f1c8b1b9c..7884a95c48c9a1 100644 --- a/test/expect/TestJit.test_lstm_fusion_concat.expect +++ b/test/expect/TestJit.test_lstm_fusion_concat.expect @@ -6,38 +6,44 @@ graph(%0 : Float(3, 10) %5 : Float(80) %6 : Float(80)) { %7 : Float(10!, 80!) = aten::t(%3) - %8 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=0](%5) - %9 : Float(3, 80) = aten::addmm[alpha={1}, beta={1}](%8, %0, %7) + %8 : int = prim::Constant[value=1]() + %9 : Float(3, 80) = aten::addmm(%5, %0, %7, %8, %8) %10 : Float(20!, 80!) = aten::t(%4) - %11 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=0](%6) - %12 : Float(3, 80) = aten::addmm[alpha={1}, beta={1}](%11, %1, %10) - %13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%9) - %17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%12) + %11 : Float(3, 80) = aten::addmm(%6, %1, %10, %8, %8) + %12 : int = prim::Constant[value=4]() + %13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk(%9, %12, %8) + %17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk(%11, %12, %8) %21 : Float(6, 20) = prim::FusionGroup_0[device=0](%2, %16, %20, %15, %19, %14, %18, %13, %17) return (%21); } -with prim::FusionGroup_0 = graph(%14 : Float(3, 20) - %24 : Float(3!, 20) - %25 : Float(3!, 20) +with prim::FusionGroup_0 = graph(%16 : Float(3, 20) + %26 : Float(3!, 20) %27 : Float(3!, 20) - %28 : Float(3!, 20) %30 : Float(3!, 20) %31 : Float(3!, 20) - %33 : Float(3!, 20) - %34 : Float(3!, 20)) { - %35 : Float(3, 20) = aten::add[alpha={1}](%33, %34) - %32 : Float(3, 20) = aten::add[alpha={1}](%30, %31) - %29 : Float(3, 20) = aten::add[alpha={1}](%27, %28) - %26 : Float(3, 20) = aten::add[alpha={1}](%24, %25) - %23 : Float(3, 20) = aten::sigmoid(%35) - %21 : Float(3, 20) = aten::sigmoid(%32) - %19 : Float(3, 20) = aten::tanh(%29) - %17 : Float(3, 20) = aten::sigmoid(%26) - %15 : Float(3, 20) = aten::mul(%21, %14) - %12 : Float(3, 20) = aten::mul(%23, %19) - %9 : Float(3, 20) = aten::add[alpha={1}](%15, %12) - %6 : Float(3, 20) = aten::tanh(%9) - %5 : Float(3, 20) = aten::mul(%17, %6) - %2 : Float(6, 20) = aten::cat[dim=0](%5, %9) - return (%2); + %34 : Float(3!, 20) + %35 : Float(3!, 20) + %38 : Float(3!, 20) + %39 : Float(3!, 20)) { + %40 : int = prim::Constant[value=1]() + %41 : Float(3, 20) = aten::add(%38, %39, %40) + %36 : int = prim::Constant[value=1]() + %37 : Float(3, 20) = aten::add(%34, %35, %36) + %32 : int = prim::Constant[value=1]() + %33 : Float(3, 20) = aten::add(%30, %31, %32) + %28 : int = prim::Constant[value=1]() + %29 : Float(3, 20) = aten::add(%26, %27, %28) + %25 : Float(3, 20) = aten::sigmoid(%41) + %23 : Float(3, 20) = aten::sigmoid(%37) + %21 : Float(3, 20) = aten::tanh(%33) + %19 : Float(3, 20) = aten::sigmoid(%29) + %17 : Float(3, 20) = aten::mul(%23, %16) + %14 : Float(3, 20) = aten::mul(%25, %21) + %10 : int = prim::Constant[value=1]() + %11 : Float(3, 20) = aten::add(%17, %14, %10) + %7 : Float(3, 20) = aten::tanh(%11) + %6 : Float(3, 20) = aten::mul(%19, %7) + %2 : int = prim::Constant[value=0]() + %3 : Float(6, 20) = aten::cat(%6, %11, %2) + return (%3); } diff --git a/test/expect/TestJit.test_lstm_fusion_cuda.expect b/test/expect/TestJit.test_lstm_fusion_cuda.expect index f2393996d11415..06be6cbb5d44a1 100644 --- a/test/expect/TestJit.test_lstm_fusion_cuda.expect +++ b/test/expect/TestJit.test_lstm_fusion_cuda.expect @@ -6,37 +6,42 @@ graph(%0 : Float(3, 10) %5 : Float(80) %6 : Float(80)) { %7 : Float(10!, 80!) = aten::t(%3) - %8 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=0](%5) - %9 : Float(3, 80) = aten::addmm[alpha={1}, beta={1}](%8, %0, %7) + %8 : int = prim::Constant[value=1]() + %9 : Float(3, 80) = aten::addmm(%5, %0, %7, %8, %8) %10 : Float(20!, 80!) = aten::t(%4) - %11 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=0](%6) - %12 : Float(3, 80) = aten::addmm[alpha={1}, beta={1}](%11, %1, %10) - %13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%9) - %17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%12) + %11 : Float(3, 80) = aten::addmm(%6, %1, %10, %8, %8) + %12 : int = prim::Constant[value=4]() + %13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk(%9, %12, %8) + %17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk(%11, %12, %8) %21 : Float(3, 20), %22 : Float(3, 20) = prim::FusionGroup_0[device=0](%2, %16, %20, %15, %19, %14, %18, %13, %17) return (%21, %22); } -with prim::FusionGroup_0 = graph(%12 : Float(3, 20) - %22 : Float(3!, 20) +with prim::FusionGroup_0 = graph(%13 : Float(3, 20) %23 : Float(3!, 20) - %25 : Float(3!, 20) - %26 : Float(3!, 20) + %24 : Float(3!, 20) + %27 : Float(3!, 20) %28 : Float(3!, 20) - %29 : Float(3!, 20) %31 : Float(3!, 20) - %32 : Float(3!, 20)) { - %33 : Float(3, 20) = aten::add[alpha={1}](%31, %32) - %30 : Float(3, 20) = aten::add[alpha={1}](%28, %29) - %27 : Float(3, 20) = aten::add[alpha={1}](%25, %26) - %24 : Float(3, 20) = aten::add[alpha={1}](%22, %23) - %21 : Float(3, 20) = aten::sigmoid(%33) - %19 : Float(3, 20) = aten::sigmoid(%30) - %17 : Float(3, 20) = aten::tanh(%27) - %15 : Float(3, 20) = aten::sigmoid(%24) - %13 : Float(3, 20) = aten::mul(%19, %12) - %10 : Float(3, 20) = aten::mul(%21, %17) - %7 : Float(3, 20) = aten::add[alpha={1}](%13, %10) - %4 : Float(3, 20) = aten::tanh(%7) - %2 : Float(3, 20) = aten::mul(%15, %4) - return (%2, %7); + %32 : Float(3!, 20) + %35 : Float(3!, 20) + %36 : Float(3!, 20)) { + %37 : int = prim::Constant[value=1]() + %38 : Float(3, 20) = aten::add(%35, %36, %37) + %33 : int = prim::Constant[value=1]() + %34 : Float(3, 20) = aten::add(%31, %32, %33) + %29 : int = prim::Constant[value=1]() + %30 : Float(3, 20) = aten::add(%27, %28, %29) + %25 : int = prim::Constant[value=1]() + %26 : Float(3, 20) = aten::add(%23, %24, %25) + %22 : Float(3, 20) = aten::sigmoid(%38) + %20 : Float(3, 20) = aten::sigmoid(%34) + %18 : Float(3, 20) = aten::tanh(%30) + %16 : Float(3, 20) = aten::sigmoid(%26) + %14 : Float(3, 20) = aten::mul(%20, %13) + %11 : Float(3, 20) = aten::mul(%22, %18) + %7 : int = prim::Constant[value=1]() + %8 : Float(3, 20) = aten::add(%14, %11, %7) + %4 : Float(3, 20) = aten::tanh(%8) + %2 : Float(3, 20) = aten::mul(%16, %4) + return (%2, %8); } diff --git a/test/expect/TestJit.test_nested_inplace.expect b/test/expect/TestJit.test_nested_inplace.expect index ff7e60b1c7d5ab..fd21055854faba 100644 --- a/test/expect/TestJit.test_nested_inplace.expect +++ b/test/expect/TestJit.test_nested_inplace.expect @@ -1,4 +1,6 @@ graph(%0 : Double(2, 2)) { - %1 : Double(2, 2) = aten::threshold[threshold={0}, value={0}](%0) - return (%1); + %1 : int = prim::Constant[value=0]() + %2 : int = prim::Constant[value=0]() + %3 : Double(2, 2) = aten::threshold(%0, %1, %2) + return (%3); } diff --git a/test/expect/TestJit.test_python_ir.expect b/test/expect/TestJit.test_python_ir.expect index 1bb094dcd8bfd3..59ed07b6fdc9f0 100644 --- a/test/expect/TestJit.test_python_ir.expect +++ b/test/expect/TestJit.test_python_ir.expect @@ -1,9 +1,10 @@ graph(%0 : Dynamic %1 : Dynamic) { - %2 : Double(1) = aten::add[alpha={1}](%0, %1) - %3 : Double(1) = aten::mul(%0, %2) - %4 : Double(1) = aten::tanh(%3) - %5 : Double(1) = aten::sigmoid(%4) - %6 : Dynamic = prim::TensorTest[a= 1 1 1 1 [ CPUDoubleType{2,2} ]]() - return (%5); + %2 : int = prim::Constant[value=1]() + %3 : Double(1) = aten::add(%0, %1, %2) + %4 : Double(1) = aten::mul(%0, %3) + %5 : Double(1) = aten::tanh(%4) + %6 : Double(1) = aten::sigmoid(%5) + %7 : Dynamic = prim::TensorTest[a= 1 1 1 1 [ CPUDoubleType{2,2} ]]() + return (%6); } diff --git a/test/expect/TestJit.test_repeated_input.expect b/test/expect/TestJit.test_repeated_input.expect index 57e57066ef503b..ac67a6c14fc972 100644 --- a/test/expect/TestJit.test_repeated_input.expect +++ b/test/expect/TestJit.test_repeated_input.expect @@ -1,5 +1,6 @@ graph(%0 : Double(2, 2) %1 : Double(2, 2)) { - %2 : Double(2, 2) = aten::add[alpha={1}](%0, %1) - return (%2); + %2 : int = prim::Constant[value=1]() + %3 : Double(2, 2) = aten::add(%0, %1, %2) + return (%3); } diff --git a/test/expect/TestJit.test_repeated_output.expect b/test/expect/TestJit.test_repeated_output.expect index b3baff631ebe0d..64a937aef7fb6a 100644 --- a/test/expect/TestJit.test_repeated_output.expect +++ b/test/expect/TestJit.test_repeated_output.expect @@ -1,5 +1,6 @@ graph(%0 : Double(2, 2) %1 : Double(2, 2)) { - %2 : Double(2, 2) = aten::add[alpha={1}](%0, %1) - return (%2, %2); + %2 : int = prim::Constant[value=1]() + %3 : Double(2, 2) = aten::add(%0, %1, %2) + return (%3, %3); } diff --git a/test/expect/TestJit.test_scopes.expect b/test/expect/TestJit.test_scopes.expect index 05578370f5cfbf..3cbbb0b966afe1 100644 --- a/test/expect/TestJit.test_scopes.expect +++ b/test/expect/TestJit.test_scopes.expect @@ -1,8 +1,9 @@ graph(%0 : Double(1) %1 : Double(1)) { - %2 : Double(1) = aten::add[alpha={1}](%0, %1) - %3 : Double(1) = aten::mul(%0, %2), scope: Foo - %4 : Double(1) = aten::tanh(%3), scope: Foo/Bar - %5 : Double(1) = aten::sigmoid(%4), scope: Foo - return (%5); + %2 : int = prim::Constant[value=1]() + %3 : Double(1) = aten::add(%0, %1, %2) + %4 : Double(1) = aten::mul(%0, %3), scope: Foo + %5 : Double(1) = aten::tanh(%4), scope: Foo/Bar + %6 : Double(1) = aten::sigmoid(%5), scope: Foo + return (%6); } diff --git a/test/expect/TestJit.test_shape_analysis_broadcast.expect b/test/expect/TestJit.test_shape_analysis_broadcast.expect index bbe5b741649d0a..e238c3fe1adc13 100644 --- a/test/expect/TestJit.test_shape_analysis_broadcast.expect +++ b/test/expect/TestJit.test_shape_analysis_broadcast.expect @@ -1,7 +1,12 @@ graph(%a : Double(3, 1, 5) %b : Double(4, 1, 8, 5)) { - %2 : Double(4!, 3!, 8!, 5) = aten::expand[size=[4, 3, 8, 5], implicit=0](%a) - %3 : Double(4!, 3!, 8, 5) = aten::expand[size=[4, 3, 8, 5], implicit=0](%b) - %4 : Double(4, 3, 8, 5) = aten::add[alpha={1}](%2, %3) - return (%4); + %2 : int = prim::Constant[value=1]() + %3 : int[] = prim::Constant[value=[4, 3, 8, 5]]() + %4 : int = prim::Constant[value=0]() + %5 : Double(4!, 3!, 8!, 5) = aten::expand(%a, %3, %4) + %6 : int[] = prim::Constant[value=[4, 3, 8, 5]]() + %7 : int = prim::Constant[value=0]() + %8 : Double(4!, 3!, 8, 5) = aten::expand(%b, %6, %7) + %9 : Double(4, 3, 8, 5) = aten::add(%5, %8, %2) + return (%9); } diff --git a/test/expect/TestJit.test_shared_param.expect b/test/expect/TestJit.test_shared_param.expect index ec758dfb7d87af..1b0a2c25be34bb 100644 --- a/test/expect/TestJit.test_shared_param.expect +++ b/test/expect/TestJit.test_shared_param.expect @@ -1,6 +1,7 @@ graph(%0 : Double(2, 2) %1 : Double(2, 2)) { %2 : Double(2, 2) = aten::mul(%0, %1), scope: MyModule - %3 : Double(2, 2) = aten::add[alpha={1}](%2, %1), scope: MyModule - return (%3); + %3 : int = prim::Constant[value=1](), scope: MyModule + %4 : Double(2, 2) = aten::add(%2, %1, %3), scope: MyModule + return (%4); } diff --git a/test/expect/TestJit.test_simple.expect b/test/expect/TestJit.test_simple.expect index 1db84de676b3e5..bfa7408b6be17c 100644 --- a/test/expect/TestJit.test_simple.expect +++ b/test/expect/TestJit.test_simple.expect @@ -1,8 +1,9 @@ graph(%0 : Double(1) %1 : Double(1)) { - %2 : Double(1) = aten::add[alpha={1}](%0, %1) - %3 : Double(1) = aten::mul(%0, %2) - %4 : Double(1) = aten::tanh(%3) - %5 : Double(1) = aten::sigmoid(%4) - return (%5); + %2 : int = prim::Constant[value=1]() + %3 : Double(1) = aten::add(%0, %1, %2) + %4 : Double(1) = aten::mul(%0, %3) + %5 : Double(1) = aten::tanh(%4) + %6 : Double(1) = aten::sigmoid(%5) + return (%6); } diff --git a/test/expect/TestJit.test_trace_size.expect b/test/expect/TestJit.test_trace_size.expect index 567a0fc5a5ecb3..8068691735c3dd 100644 --- a/test/expect/TestJit.test_trace_size.expect +++ b/test/expect/TestJit.test_trace_size.expect @@ -2,14 +2,15 @@ graph(%0 : Double(5, 2, 4)) { %1 : int = prim::Constant[value=1]() %2 : int = aten::size(%0, %1) %3 : Long() = prim::NumToTensor(%2) - %4 : Long() = aten::mul[other={2}](%3) - %5 : int = prim::TensorToNum(%4) - %6 : int = prim::Constant[value=0]() - %7 : int = aten::size(%0, %6) - %8 : Long() = prim::NumToTensor(%7) - %9 : int = prim::TensorToNum(%8) - %10 : int = prim::Constant[value=2]() - %11 : int[] = prim::ListConstruct(%5, %9, %10) - %12 : Double(4, 5, 2) = aten::view(%0, %11) - return (%12); + %4 : int = prim::Constant[value=2]() + %5 : Long() = aten::mul(%3, %4) + %6 : int = prim::TensorToNum(%5) + %7 : int = prim::Constant[value=0]() + %8 : int = aten::size(%0, %7) + %9 : Long() = prim::NumToTensor(%8) + %10 : int = prim::TensorToNum(%9) + %11 : int = prim::Constant[value=2]() + %12 : int[] = prim::ListConstruct(%6, %10, %11) + %13 : Double(4, 5, 2) = aten::view(%0, %12) + return (%13); } diff --git a/test/expect/TestJit.test_trace_size_with_grad.expect b/test/expect/TestJit.test_trace_size_with_grad.expect index 567a0fc5a5ecb3..8068691735c3dd 100644 --- a/test/expect/TestJit.test_trace_size_with_grad.expect +++ b/test/expect/TestJit.test_trace_size_with_grad.expect @@ -2,14 +2,15 @@ graph(%0 : Double(5, 2, 4)) { %1 : int = prim::Constant[value=1]() %2 : int = aten::size(%0, %1) %3 : Long() = prim::NumToTensor(%2) - %4 : Long() = aten::mul[other={2}](%3) - %5 : int = prim::TensorToNum(%4) - %6 : int = prim::Constant[value=0]() - %7 : int = aten::size(%0, %6) - %8 : Long() = prim::NumToTensor(%7) - %9 : int = prim::TensorToNum(%8) - %10 : int = prim::Constant[value=2]() - %11 : int[] = prim::ListConstruct(%5, %9, %10) - %12 : Double(4, 5, 2) = aten::view(%0, %11) - return (%12); + %4 : int = prim::Constant[value=2]() + %5 : Long() = aten::mul(%3, %4) + %6 : int = prim::TensorToNum(%5) + %7 : int = prim::Constant[value=0]() + %8 : int = aten::size(%0, %7) + %9 : Long() = prim::NumToTensor(%8) + %10 : int = prim::TensorToNum(%9) + %11 : int = prim::Constant[value=2]() + %12 : int[] = prim::ListConstruct(%6, %10, %11) + %13 : Double(4, 5, 2) = aten::view(%0, %12) + return (%13); } diff --git a/test/expect/TestScript.test_call_python_fn_from_script_fn.expect b/test/expect/TestScript.test_call_python_fn_from_script_fn.expect index db478d2e22f9cb..297d0918700a02 100644 --- a/test/expect/TestScript.test_call_python_fn_from_script_fn.expect +++ b/test/expect/TestScript.test_call_python_fn_from_script_fn.expect @@ -1,5 +1,7 @@ graph(%x : Dynamic) { %1 : Dynamic = ^python_fn()(%x) - %5 : Dynamic = aten::add[other={1}, alpha={1}](%1) + %2 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=1]() + %5 : Dynamic = aten::add(%1, %2, %4) return (%5); } diff --git a/test/expect/TestScript.test_call_python_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_python_fn_from_tracing_fn.expect index 4eb19bbc83c7e7..ac76985db76fb3 100644 --- a/test/expect/TestScript.test_call_python_fn_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_python_fn_from_tracing_fn.expect @@ -1,5 +1,7 @@ graph(%0 : Double(3, 4)) { %1 : Double(3, 4) = aten::neg(%0) - %2 : Double(3, 4) = aten::add[other={1}, alpha={1}](%1) - return (%2); + %2 : int = prim::Constant[value=1]() + %3 : int = prim::Constant[value=1]() + %4 : Double(3, 4) = aten::add(%1, %2, %3) + return (%4); } diff --git a/test/expect/TestScript.test_call_python_mod_from_script_fn.expect b/test/expect/TestScript.test_call_python_mod_from_script_fn.expect index ec5fd842f3b864..260bbaba6462f7 100644 --- a/test/expect/TestScript.test_call_python_mod_from_script_fn.expect +++ b/test/expect/TestScript.test_call_python_mod_from_script_fn.expect @@ -1,5 +1,7 @@ graph(%x : Dynamic) { %1 : Dynamic = ^()(%x) - %5 : Dynamic = aten::add[other={1}, alpha={1}](%1) + %2 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=1]() + %5 : Dynamic = aten::add(%1, %2, %4) return (%5); } diff --git a/test/expect/TestScript.test_call_python_mod_from_traced_module.expect b/test/expect/TestScript.test_call_python_mod_from_traced_module.expect index d39acaf5257d3d..863cbdf2d5a4fa 100644 --- a/test/expect/TestScript.test_call_python_mod_from_traced_module.expect +++ b/test/expect/TestScript.test_call_python_mod_from_traced_module.expect @@ -3,6 +3,8 @@ graph(%0 : Double(3, 4) %2 : Double(5, 7)) { %4 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule %6 : Double(3, 7) = aten::mm(%4, %2), scope: TracedModule/PythonModule[mod] - %7 : Double(3, 7) = aten::add[other={1}, alpha={1}](%6), scope: TracedModule - return (%7); + %7 : int = prim::Constant[value=1](), scope: TracedModule + %8 : int = prim::Constant[value=1](), scope: TracedModule + %9 : Double(3, 7) = aten::add(%6, %7, %8), scope: TracedModule + return (%9); } diff --git a/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect index ea847d630c8ba5..a23a6bd2730368 100644 --- a/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect @@ -1,6 +1,8 @@ graph(%0 : Double(3, 4)) { %1 : Double(4, 3) = prim::Constant[value=](), scope: PythonMod %3 : Double(3, 3) = aten::mm(%0, %1), scope: PythonMod - %4 : Double(3, 3) = aten::add[other={1}, alpha={1}](%3) - return (%4); + %4 : int = prim::Constant[value=1]() + %5 : int = prim::Constant[value=1]() + %6 : Double(3, 3) = aten::add(%3, %4, %5) + return (%6); } diff --git a/test/expect/TestScript.test_call_script_fn_from_script_fn.expect b/test/expect/TestScript.test_call_script_fn_from_script_fn.expect index e36a68926dccce..8c23ad8f353bef 100644 --- a/test/expect/TestScript.test_call_script_fn_from_script_fn.expect +++ b/test/expect/TestScript.test_call_script_fn_from_script_fn.expect @@ -1,5 +1,7 @@ graph(%x : Dynamic) { %1 : Dynamic = aten::neg(%x) - %5 : Dynamic = aten::add[other={1}, alpha={1}](%1) + %2 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=1]() + %5 : Dynamic = aten::add(%1, %2, %4) return (%5); } diff --git a/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect index dc8b4945df4773..d12a6a40520bdf 100644 --- a/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect @@ -1,5 +1,7 @@ graph(%0 : Double(3, 4)) { %2 : Double(3, 4) = aten::neg(%0), scope: ScriptModule - %3 : Double(3, 4) = aten::add[other={1}, alpha={1}](%2) - return (%3); + %3 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=1]() + %5 : Double(3, 4) = aten::add(%2, %3, %4) + return (%5); } diff --git a/test/expect/TestScript.test_call_script_mod_from_script_fn.expect b/test/expect/TestScript.test_call_script_mod_from_script_fn.expect index e24d034b26e3da..98cf4ade03b461 100644 --- a/test/expect/TestScript.test_call_script_mod_from_script_fn.expect +++ b/test/expect/TestScript.test_call_script_mod_from_script_fn.expect @@ -7,6 +7,8 @@ graph(%x : Dynamic) { %6 : int[] = prim::ListConstruct(%1, %2) %7 : Dynamic = aten::zeros(%6, %3, %4, %5) %8 : Dynamic = aten::mm(%x, %7) - %12 : Dynamic = aten::add[other={1}, alpha={1}](%8) + %9 : int = prim::Constant[value=1]() + %11 : int = prim::Constant[value=1]() + %12 : Dynamic = aten::add(%8, %9, %11) return (%12); } diff --git a/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect index fc7039bd971f23..4c098d3b9f16a4 100644 --- a/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect @@ -1,6 +1,8 @@ graph(%0 : Double(3, 4)) { %1 : Double(4, 3) = prim::Constant[value=](), scope: ScriptMod %4 : Double(3, 3) = aten::mm(%0, %1), scope: ScriptMod - %5 : Double(3, 3) = aten::add[other={1}, alpha={1}](%4) - return (%5); + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() + %7 : Double(3, 3) = aten::add(%4, %5, %6) + return (%7); } diff --git a/test/expect/TestScript.test_call_script_module_from_traced_module.expect b/test/expect/TestScript.test_call_script_module_from_traced_module.expect index 21b14a2a62f8cf..1a452935eb5fc8 100644 --- a/test/expect/TestScript.test_call_script_module_from_traced_module.expect +++ b/test/expect/TestScript.test_call_script_module_from_traced_module.expect @@ -3,6 +3,8 @@ graph(%0 : Double(3, 4) %2 : Double(5, 7)) { %4 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule %7 : Double(3, 7) = aten::mm(%4, %2), scope: TracedModule/ScriptMod[mod] - %8 : Double(3, 7) = aten::add[other={1}, alpha={1}](%7), scope: TracedModule - return (%8); + %8 : int = prim::Constant[value=1](), scope: TracedModule + %9 : int = prim::Constant[value=1](), scope: TracedModule + %10 : Double(3, 7) = aten::add(%7, %8, %9), scope: TracedModule + return (%10); } diff --git a/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect b/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect index 83ce62e68e086d..2cae32e2be01a7 100644 --- a/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect +++ b/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect @@ -1,5 +1,7 @@ graph(%x : Dynamic) { %1 : Double(3, 4) = aten::neg(%x) - %5 : Dynamic = aten::add[other={1}, alpha={1}](%1) + %2 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=1]() + %5 : Dynamic = aten::add(%1, %2, %4) return (%5); } diff --git a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect index ed737f4b6580b4..27eb2b6e7814e0 100644 --- a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect @@ -1,5 +1,7 @@ graph(%0 : Double(3, 4)) { %2 : Double(3, 4) = aten::neg(%0), scope: traced_fn1 - %3 : Double(3, 4) = aten::add[other={1}, alpha={1}](%2) - return (%3); + %3 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=1]() + %5 : Double(3, 4) = aten::add(%2, %3, %4) + return (%5); } diff --git a/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect b/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect index 9a99fbe83f1d4d..315ec3464487be 100644 --- a/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect +++ b/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect @@ -1,6 +1,8 @@ graph(%x : Dynamic) { %1 : Double(4, 3) = prim::Constant[value=]() %2 : Double(3, 3) = aten::mm(%x, %1) - %6 : Dynamic = aten::add[other={1}, alpha={1}](%2) + %3 : int = prim::Constant[value=1]() + %5 : int = prim::Constant[value=1]() + %6 : Dynamic = aten::add(%2, %3, %5) return (%6); } diff --git a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect index 3fac45fc2dfdab..f5c6f1bb2c18d8 100644 --- a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect @@ -1,6 +1,8 @@ graph(%0 : Double(3, 4)) { %1 : Double(4, 3) = prim::Constant[value=](), scope: TracedModule[TracedModule] %4 : Double(3, 3) = aten::mm(%0, %1), scope: TracedModule[TracedModule] - %5 : Double(3, 3) = aten::add[other={1}, alpha={1}](%4) - return (%5); + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() + %7 : Double(3, 3) = aten::add(%4, %5, %6) + return (%7); } diff --git a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect b/test/expect/TestScript.test_call_traced_module_from_traced_module.expect index 471f9f1c2ec3fe..f66573f6da2f25 100644 --- a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect +++ b/test/expect/TestScript.test_call_traced_module_from_traced_module.expect @@ -3,6 +3,8 @@ graph(%0 : Double(3, 4) %2 : Double(5, 7)) { %4 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule %7 : Double(3, 7) = aten::mm(%4, %2), scope: TracedModule/TracedModule[TracedModule1][mod] - %8 : Double(3, 7) = aten::add[other={1}, alpha={1}](%7), scope: TracedModule - return (%8); + %8 : int = prim::Constant[value=1](), scope: TracedModule + %9 : int = prim::Constant[value=1](), scope: TracedModule + %10 : Double(3, 7) = aten::add(%7, %8, %9), scope: TracedModule + return (%10); } diff --git a/test/expect/TestScript.test_cat_lifts.expect b/test/expect/TestScript.test_cat_lifts.expect index 5bcef43f7c7a3d..ea2fa3737c0556 100644 --- a/test/expect/TestScript.test_cat_lifts.expect +++ b/test/expect/TestScript.test_cat_lifts.expect @@ -1,12 +1,15 @@ graph(%x : Dynamic) { - %1 : Dynamic = aten::cat[dim=1](%x, %x) - return (%1); + %1 : int = prim::Constant[value=1]() + %2 : Dynamic = aten::cat(%x, %x, %1) + return (%2); } graph(%x : Dynamic) { - %1 : Dynamic = aten::cat[dim=1]() - return (%1); + %1 : int = prim::Constant[value=1]() + %2 : Dynamic = aten::cat(%1) + return (%2); } graph(%x : Dynamic) { - %1 : Dynamic = aten::cat[dim=1](%x) - return (%1); + %1 : int = prim::Constant[value=1]() + %2 : Dynamic = aten::cat(%x, %1) + return (%2); } diff --git a/test/expect/TestScript.test_index_put_trace_with_view.expect b/test/expect/TestScript.test_index_put_trace_with_view.expect index 24ff0fe32c451f..591e499da96671 100644 --- a/test/expect/TestScript.test_index_put_trace_with_view.expect +++ b/test/expect/TestScript.test_index_put_trace_with_view.expect @@ -1,8 +1,11 @@ graph(%0 : Double(100) %1 : Long(4) %2 : Double(1, 1, 1, 4)) { - %3 : Double(4) = aten::view[size=[4]](%2) - %4 : Long(4) = aten::_cast_Long[non_blocking=0](%1) - %11 : Double(100) = aten::index_put(%0, %4, %3) - return (%11); + %3 : int = prim::Constant[value=4]() + %4 : int[] = prim::ListConstruct(%3) + %5 : Double(4) = aten::view(%2, %4) + %6 : int = prim::Constant[value=0]() + %7 : Long(4) = aten::_cast_Long(%1, %6) + %19 : Double(100) = aten::index_put(%0, %7, %5) + return (%19); } diff --git a/test/expect/TestScript.test_index_put_trace_without_view.expect b/test/expect/TestScript.test_index_put_trace_without_view.expect index f483213b481461..42f8e49142942e 100644 --- a/test/expect/TestScript.test_index_put_trace_without_view.expect +++ b/test/expect/TestScript.test_index_put_trace_without_view.expect @@ -1,7 +1,8 @@ graph(%0 : Double(100) %1 : Long(4) %2 : Double(4)) { - %3 : Long(4) = aten::_cast_Long[non_blocking=0](%1) - %10 : Double(100) = aten::index_put(%0, %3, %2) - return (%10); + %3 : int = prim::Constant[value=0]() + %4 : Long(4) = aten::_cast_Long(%1, %3) + %16 : Double(100) = aten::index_put(%0, %4, %2) + return (%16); } diff --git a/test/expect/TestScript.test_index_select_shape_prop.expect b/test/expect/TestScript.test_index_select_shape_prop.expect index 32a9d7744e52cc..f24249a21f9d20 100644 --- a/test/expect/TestScript.test_index_select_shape_prop.expect +++ b/test/expect/TestScript.test_index_select_shape_prop.expect @@ -1,5 +1,6 @@ graph(%x : Double(2, 2) %y : Long(4)) { - %2 : Double(2, 4) = aten::index_select[dim=1](%x, %y) - return (%2); + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::index_select(%x, %2, %y) + return (%3); } diff --git a/test/expect/TestScript.test_loop_unroll_unused_counter.expect b/test/expect/TestScript.test_loop_unroll_unused_counter.expect index a4b5983c1e5d9c..be1a5efeecf449 100644 --- a/test/expect/TestScript.test_loop_unroll_unused_counter.expect +++ b/test/expect/TestScript.test_loop_unroll_unused_counter.expect @@ -9,22 +9,40 @@ graph(%x : Dynamic) { %8 : int = aten::sub(%2, %7) %y.3 : Dynamic = prim::Loop(%5, %3, %y.1) block0(%i.1 : int, %11 : Dynamic) { - %y.12 : Dynamic = aten::add[other={1}, alpha={1}](%11) - %y.5 : Dynamic = aten::add[other={1}, alpha={1}](%y.12) - %y.6 : Dynamic = aten::add[other={1}, alpha={1}](%y.5) - %y.7 : Dynamic = aten::add[other={1}, alpha={1}](%y.6) - %y.8 : Dynamic = aten::add[other={1}, alpha={1}](%y.7) - %y.9 : Dynamic = aten::add[other={1}, alpha={1}](%y.8) - %y.10 : Dynamic = aten::add[other={1}, alpha={1}](%y.9) - %y.11 : Dynamic = aten::add[other={1}, alpha={1}](%y.10) - %20 : int = prim::Constant[value=1]() - -> (%20, %y.11) + %12 : int = prim::Constant[value=1]() + %13 : int = prim::Constant[value=1]() + %y.12 : Dynamic = aten::add(%11, %12, %13) + %15 : int = prim::Constant[value=1]() + %16 : int = prim::Constant[value=1]() + %y.5 : Dynamic = aten::add(%y.12, %15, %16) + %18 : int = prim::Constant[value=1]() + %19 : int = prim::Constant[value=1]() + %y.6 : Dynamic = aten::add(%y.5, %18, %19) + %21 : int = prim::Constant[value=1]() + %22 : int = prim::Constant[value=1]() + %y.7 : Dynamic = aten::add(%y.6, %21, %22) + %24 : int = prim::Constant[value=1]() + %25 : int = prim::Constant[value=1]() + %y.8 : Dynamic = aten::add(%y.7, %24, %25) + %27 : int = prim::Constant[value=1]() + %28 : int = prim::Constant[value=1]() + %y.9 : Dynamic = aten::add(%y.8, %27, %28) + %30 : int = prim::Constant[value=1]() + %31 : int = prim::Constant[value=1]() + %y.10 : Dynamic = aten::add(%y.9, %30, %31) + %33 : int = prim::Constant[value=1]() + %34 : int = prim::Constant[value=1]() + %y.11 : Dynamic = aten::add(%y.10, %33, %34) + %36 : int = prim::Constant[value=1]() + -> (%36, %y.11) } %y : Dynamic = prim::Loop(%8, %3, %y.3) - block0(%i : int, %23 : Dynamic) { - %y.4 : Dynamic = aten::add[other={1}, alpha={1}](%23) - %25 : int = prim::Constant[value=1]() - -> (%25, %y.4) + block0(%i : int, %39 : Dynamic) { + %40 : int = prim::Constant[value=1]() + %41 : int = prim::Constant[value=1]() + %y.4 : Dynamic = aten::add(%39, %40, %41) + %43 : int = prim::Constant[value=1]() + -> (%43, %y.4) } return (%y); } diff --git a/test/expect/TestScript.test_loop_unrolling.expect b/test/expect/TestScript.test_loop_unrolling.expect index 0c77a4ec47e6ec..fc0ca446112036 100644 --- a/test/expect/TestScript.test_loop_unrolling.expect +++ b/test/expect/TestScript.test_loop_unrolling.expect @@ -10,35 +10,35 @@ graph(%x : Dynamic) { %9 : int = aten::sub(%2, %8) %10 : Dynamic, %y.3 : Dynamic = prim::Loop(%6, %3, %4, %y.1) block0(%i.1 : int, %13 : Dynamic, %14 : Dynamic) { - %15 : Number = prim::Constant[value=1]() + %15 : int = prim::Constant[value=1]() %y.12 : Dynamic = aten::add(%14, %13, %15) %17 : int = prim::Constant[value=1]() %18 : int = aten::add(%13, %17) - %19 : Number = prim::Constant[value=1]() + %19 : int = prim::Constant[value=1]() %y.5 : Dynamic = aten::add(%y.12, %18, %19) %21 : int = prim::Constant[value=1]() %22 : int = aten::add(%18, %21) - %23 : Number = prim::Constant[value=1]() + %23 : int = prim::Constant[value=1]() %y.6 : Dynamic = aten::add(%y.5, %22, %23) %25 : int = prim::Constant[value=1]() %26 : int = aten::add(%22, %25) - %27 : Number = prim::Constant[value=1]() + %27 : int = prim::Constant[value=1]() %y.7 : Dynamic = aten::add(%y.6, %26, %27) %29 : int = prim::Constant[value=1]() %30 : int = aten::add(%26, %29) - %31 : Number = prim::Constant[value=1]() + %31 : int = prim::Constant[value=1]() %y.8 : Dynamic = aten::add(%y.7, %30, %31) %33 : int = prim::Constant[value=1]() %34 : int = aten::add(%30, %33) - %35 : Number = prim::Constant[value=1]() + %35 : int = prim::Constant[value=1]() %y.9 : Dynamic = aten::add(%y.8, %34, %35) %37 : int = prim::Constant[value=1]() %38 : int = aten::add(%34, %37) - %39 : Number = prim::Constant[value=1]() + %39 : int = prim::Constant[value=1]() %y.10 : Dynamic = aten::add(%y.9, %38, %39) %41 : int = prim::Constant[value=1]() %42 : int = aten::add(%38, %41) - %43 : Number = prim::Constant[value=1]() + %43 : int = prim::Constant[value=1]() %y.11 : Dynamic = aten::add(%y.10, %42, %43) %45 : int = prim::Constant[value=1]() %46 : int = prim::Constant[value=1]() @@ -47,7 +47,7 @@ graph(%x : Dynamic) { } %48 : Dynamic, %y : Dynamic = prim::Loop(%9, %3, %10, %y.3) block0(%i : int, %51 : Dynamic, %52 : Dynamic) { - %53 : Number = prim::Constant[value=1]() + %53 : int = prim::Constant[value=1]() %y.4 : Dynamic = aten::add(%52, %51, %53) %55 : int = prim::Constant[value=1]() %56 : int = prim::Constant[value=1]() diff --git a/test/expect/TestScript.test_loop_unrolling_const-add_const.expect b/test/expect/TestScript.test_loop_unrolling_const-add_const.expect index 8f810b0a6339bf..6b7d615a1b7800 100644 --- a/test/expect/TestScript.test_loop_unrolling_const-add_const.expect +++ b/test/expect/TestScript.test_loop_unrolling_const-add_const.expect @@ -1,14 +1,34 @@ graph() { %y.1 : Dynamic = ^FIXME_zerol()() - %y.11 : Dynamic = aten::add[other={1}, alpha={1}](%y.1) - %y.2 : Dynamic = aten::add[other={1}, alpha={1}](%y.11) - %y.3 : Dynamic = aten::add[other={1}, alpha={1}](%y.2) - %y.4 : Dynamic = aten::add[other={1}, alpha={1}](%y.3) - %y.5 : Dynamic = aten::add[other={1}, alpha={1}](%y.4) - %y.6 : Dynamic = aten::add[other={1}, alpha={1}](%y.5) - %y.7 : Dynamic = aten::add[other={1}, alpha={1}](%y.6) - %y.8 : Dynamic = aten::add[other={1}, alpha={1}](%y.7) - %y.9 : Dynamic = aten::add[other={1}, alpha={1}](%y.8) - %y.10 : Dynamic = aten::add[other={1}, alpha={1}](%y.9) + %1 : int = prim::Constant[value=1]() + %2 : int = prim::Constant[value=1]() + %y.11 : Dynamic = aten::add(%y.1, %1, %2) + %4 : int = prim::Constant[value=1]() + %5 : int = prim::Constant[value=1]() + %y.2 : Dynamic = aten::add(%y.11, %4, %5) + %7 : int = prim::Constant[value=1]() + %8 : int = prim::Constant[value=1]() + %y.3 : Dynamic = aten::add(%y.2, %7, %8) + %10 : int = prim::Constant[value=1]() + %11 : int = prim::Constant[value=1]() + %y.4 : Dynamic = aten::add(%y.3, %10, %11) + %13 : int = prim::Constant[value=1]() + %14 : int = prim::Constant[value=1]() + %y.5 : Dynamic = aten::add(%y.4, %13, %14) + %16 : int = prim::Constant[value=1]() + %17 : int = prim::Constant[value=1]() + %y.6 : Dynamic = aten::add(%y.5, %16, %17) + %19 : int = prim::Constant[value=1]() + %20 : int = prim::Constant[value=1]() + %y.7 : Dynamic = aten::add(%y.6, %19, %20) + %22 : int = prim::Constant[value=1]() + %23 : int = prim::Constant[value=1]() + %y.8 : Dynamic = aten::add(%y.7, %22, %23) + %25 : int = prim::Constant[value=1]() + %26 : int = prim::Constant[value=1]() + %y.9 : Dynamic = aten::add(%y.8, %25, %26) + %28 : int = prim::Constant[value=1]() + %29 : int = prim::Constant[value=1]() + %y.10 : Dynamic = aten::add(%y.9, %28, %29) return (%y.10); } diff --git a/test/expect/TestScript.test_loop_unrolling_const-add_iter.expect b/test/expect/TestScript.test_loop_unrolling_const-add_iter.expect index 2618493dc8ecb4..ba142cc8092cd3 100644 --- a/test/expect/TestScript.test_loop_unrolling_const-add_iter.expect +++ b/test/expect/TestScript.test_loop_unrolling_const-add_iter.expect @@ -1,43 +1,43 @@ graph() { %y.1 : Dynamic = ^FIXME_zerol()() %1 : int = prim::Constant[value=0]() - %2 : Number = prim::Constant[value=1]() + %2 : int = prim::Constant[value=1]() %y.11 : Dynamic = aten::add(%y.1, %1, %2) %4 : int = prim::Constant[value=1]() %5 : int = aten::add(%1, %4) - %6 : Number = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() %y.2 : Dynamic = aten::add(%y.11, %5, %6) %8 : int = prim::Constant[value=1]() %9 : int = aten::add(%5, %8) - %10 : Number = prim::Constant[value=1]() + %10 : int = prim::Constant[value=1]() %y.3 : Dynamic = aten::add(%y.2, %9, %10) %12 : int = prim::Constant[value=1]() %13 : int = aten::add(%9, %12) - %14 : Number = prim::Constant[value=1]() + %14 : int = prim::Constant[value=1]() %y.4 : Dynamic = aten::add(%y.3, %13, %14) %16 : int = prim::Constant[value=1]() %17 : int = aten::add(%13, %16) - %18 : Number = prim::Constant[value=1]() + %18 : int = prim::Constant[value=1]() %y.5 : Dynamic = aten::add(%y.4, %17, %18) %20 : int = prim::Constant[value=1]() %21 : int = aten::add(%17, %20) - %22 : Number = prim::Constant[value=1]() + %22 : int = prim::Constant[value=1]() %y.6 : Dynamic = aten::add(%y.5, %21, %22) %24 : int = prim::Constant[value=1]() %25 : int = aten::add(%21, %24) - %26 : Number = prim::Constant[value=1]() + %26 : int = prim::Constant[value=1]() %y.7 : Dynamic = aten::add(%y.6, %25, %26) %28 : int = prim::Constant[value=1]() %29 : int = aten::add(%25, %28) - %30 : Number = prim::Constant[value=1]() + %30 : int = prim::Constant[value=1]() %y.8 : Dynamic = aten::add(%y.7, %29, %30) %32 : int = prim::Constant[value=1]() %33 : int = aten::add(%29, %32) - %34 : Number = prim::Constant[value=1]() + %34 : int = prim::Constant[value=1]() %y.9 : Dynamic = aten::add(%y.8, %33, %34) %36 : int = prim::Constant[value=1]() %37 : int = aten::add(%33, %36) - %38 : Number = prim::Constant[value=1]() + %38 : int = prim::Constant[value=1]() %y.10 : Dynamic = aten::add(%y.9, %37, %38) return (%y.10); } diff --git a/test/expect/TestScript.test_loop_unrolling_nested.expect b/test/expect/TestScript.test_loop_unrolling_nested.expect index 3b8832d03071a2..cac82d3f3ba210 100644 --- a/test/expect/TestScript.test_loop_unrolling_nested.expect +++ b/test/expect/TestScript.test_loop_unrolling_nested.expect @@ -14,35 +14,35 @@ graph(%x : Dynamic) { %14 : int = aten::sub(%7, %13) %15 : Dynamic, %y.4 : Dynamic = prim::Loop(%11, %8, %9, %6) block0(%j.1 : int, %18 : Dynamic, %19 : Dynamic) { - %20 : Number = prim::Constant[value=1]() + %20 : int = prim::Constant[value=1]() %y.13 : Dynamic = aten::add(%19, %18, %20) %22 : int = prim::Constant[value=1]() %23 : int = aten::add(%18, %22) - %24 : Number = prim::Constant[value=1]() + %24 : int = prim::Constant[value=1]() %y.6 : Dynamic = aten::add(%y.13, %23, %24) %26 : int = prim::Constant[value=1]() %27 : int = aten::add(%23, %26) - %28 : Number = prim::Constant[value=1]() + %28 : int = prim::Constant[value=1]() %y.7 : Dynamic = aten::add(%y.6, %27, %28) %30 : int = prim::Constant[value=1]() %31 : int = aten::add(%27, %30) - %32 : Number = prim::Constant[value=1]() + %32 : int = prim::Constant[value=1]() %y.8 : Dynamic = aten::add(%y.7, %31, %32) %34 : int = prim::Constant[value=1]() %35 : int = aten::add(%31, %34) - %36 : Number = prim::Constant[value=1]() + %36 : int = prim::Constant[value=1]() %y.9 : Dynamic = aten::add(%y.8, %35, %36) %38 : int = prim::Constant[value=1]() %39 : int = aten::add(%35, %38) - %40 : Number = prim::Constant[value=1]() + %40 : int = prim::Constant[value=1]() %y.10 : Dynamic = aten::add(%y.9, %39, %40) %42 : int = prim::Constant[value=1]() %43 : int = aten::add(%39, %42) - %44 : Number = prim::Constant[value=1]() + %44 : int = prim::Constant[value=1]() %y.11 : Dynamic = aten::add(%y.10, %43, %44) %46 : int = prim::Constant[value=1]() %47 : int = aten::add(%43, %46) - %48 : Number = prim::Constant[value=1]() + %48 : int = prim::Constant[value=1]() %y.12 : Dynamic = aten::add(%y.11, %47, %48) %50 : int = prim::Constant[value=1]() %51 : int = prim::Constant[value=1]() @@ -51,7 +51,7 @@ graph(%x : Dynamic) { } %53 : Dynamic, %y.3 : Dynamic = prim::Loop(%14, %8, %15, %y.4) block0(%j : int, %56 : Dynamic, %57 : Dynamic) { - %58 : Number = prim::Constant[value=1]() + %58 : int = prim::Constant[value=1]() %y.5 : Dynamic = aten::add(%57, %56, %58) %60 : int = prim::Constant[value=1]() %61 : int = prim::Constant[value=1]() diff --git a/test/expect/TestScript.test_math_schema.expect b/test/expect/TestScript.test_math_schema.expect index 7d8f8d2800e84c..cff719dabb8ec6 100644 --- a/test/expect/TestScript.test_math_schema.expect +++ b/test/expect/TestScript.test_math_schema.expect @@ -1,5 +1,6 @@ graph(%x : Dynamic %y : Dynamic) { - %2 : Dynamic = aten::add[alpha={1}](%x, %y) - return (%2); + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::add(%x, %y, %2) + return (%3); } diff --git a/test/expect/TestScript.test_math_tensor_number.expect b/test/expect/TestScript.test_math_tensor_number.expect index c0a88913280e59..fb4b81bd00cba5 100644 --- a/test/expect/TestScript.test_math_tensor_number.expect +++ b/test/expect/TestScript.test_math_tensor_number.expect @@ -1,4 +1,6 @@ graph(%x : Dynamic) { - %1 : Dynamic = aten::add[other={7}, alpha={1}](%x) - return (%1); + %1 : int = prim::Constant[value=7]() + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::add(%x, %1, %2) + return (%3); } diff --git a/test/onnx/expect/TestOperators.test_batchnorm_training.expect b/test/onnx/expect/TestOperators.test_batchnorm_training.expect index 9bdadb572b7c03..24cdc2529af7bf 100644 --- a/test/onnx/expect/TestOperators.test_batchnorm_training.expect +++ b/test/onnx/expect/TestOperators.test_batchnorm_training.expect @@ -11,8 +11,8 @@ graph { output: "6" output: "7" output: "8" - output: "batch_norm_dead_output-9" - output: "batch_norm_dead_output-10" + output: "batch_norm_dead_output-13" + output: "batch_norm_dead_output-14" op_type: "BatchNormalization" attribute { name: "epsilon" diff --git a/test/test_jit.py b/test/test_jit.py index 2cccfe97ba0e7c..e7ac44d3c90967 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -880,6 +880,7 @@ def fn(a, b): def test_alexnet(self): x = torch.ones(1, 3, 224, 224) trace, _ = torch.jit.get_trace_graph(torchvision.models.AlexNet(), x) + self.run_pass('cse', trace) self.assertExpectedGraph(trace) # Inplace copies don't work with tracer yet. @@ -1144,20 +1145,20 @@ def tanh(a): def test_batch_elementwise_binary(self): @torch.jit.batch(batch_size=4) - def add(a, b): - return a + b + def mul(a, b): + return a * b xs, batch = self.rand_batch(4, (True, 3), (False, 2)) xs2, batch2 = xs, batch - res_batch = add(batch, batch2) - res = [torch.add(xs[j], xs2[j]) for j in range(4)] + res_batch = mul(batch, batch2) + res = [torch.mul(xs[j], xs2[j]) for j in range(4)] self.assertEqual(res, res_batch.examples()) # test broadcast xs, batch = self.rand_batch(4, (False, 3), (False, 2)) b = torch.rand(3, 2) - res_batch = add(batch, b) - res = [torch.add(xs[j], b) for j in range(4)] + res_batch = mul(batch, b) + res = [torch.mul(xs[j], b) for j in range(4)] self.assertEqual(res, res_batch.examples()) def test_batch_mm(self): @@ -1220,6 +1221,7 @@ def where(c, a, b): res = [torch.where(xs_cond[j], xs[j], xs2[j]) for j in range(4)] self.assertEqual(res, res_batch.examples()) + @unittest.skip("Need support for scalar arguments") def test_lstm_cell(self): def LSTMCell(x, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c): i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index b31cd45ec47a6d..9050009f62c38e 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -134,13 +134,7 @@ PRE_RECORD_TRACE = CodeTemplate("""\ jit::tracer::PreTraceInfo trace_info; if (jit::tracer::isTracing()) { - trace_info = jit::tracer::preRecordTrace( jit::aten::${trace_name}, ${trace_inputs} ); - if (!jit::tracer::ArgumentStash::empty()) { - ${record_positional_attributes} - AT_ASSERT(jit::tracer::ArgumentStash::empty()); - } else { - ${record_attributes} - } + trace_info = jit::tracer::preRecordTrace(jit::aten::${trace_name}, ${trace_inputs}); } """) @@ -387,55 +381,8 @@ def emit_record_trace(env): if not should_trace(declaration): return ('', '') - # Note [clang-802.0.42 tuple overload bug] - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Originally, my plan for emit_$ecord_trace was to keep it as - # simple as possible, if at the expense of some somewhat ugly - # overloads. So this meant we had a 'recordTrace' function - # with overloads like this: - # - # recordTrace(..., const Variable& out) - # recordTrace(..., const std::tuple& out) - # - # Unfortunately, this triggers a bug in clang-802.0.42 - # (widely used in macOS Sierra 10.12.6) wherein a Variable is - # implicitly convertible into a std::tuple; - # a minimal repro can be seen below here: - # - # #include - # struct T {}; - # void f(const std::tuple&) {} - # void g(T& x) { f(x); } - # - # To work around this bug, the code generator is a bit more - # complicated, and is taught how to handle this situation. - local = {} - - tensor_args = [arg for arg in declaration['arguments'] if arg['simple_type'] in {'Tensor', 'TensorList'}] - local['tensor_args'] = [arg['name'] for arg in tensor_args] - if any(arg['simple_type'] == 'TensorList' for arg in tensor_args): - # Allocate a temporary vector with flatten and pass it in - local['trace_inputs'] = CodeTemplate("flatten_tensor_args( $tensor_args )").substitute(local) - else: - local['trace_inputs'] = CodeTemplate("{ ${tensor_args} }").substitute(local) - - local['record_attributes'] = [] - for arg in declaration['arguments']: - if arg['simple_type'] in {'Tensor', 'TensorList'}: - continue - attr_name = RENAME_ATTRIBUTES.get((declaration['name'], arg['name']), arg['name']) - local['record_attributes'].append(RECORD_ATTRIBUTE.substitute(attr_name=attr_name, name=arg['name'])) - - local['record_positional_attributes'] = [] - for i, arg in enumerate(declaration['arguments']): - if arg['simple_type'] == 'Tensor': - continue - if arg['simple_type'] == 'TensorList': - local['record_positional_attributes'] = POSITIONAL_ATTR_NYI - break - local['record_positional_attributes'].append( - RECORD_POSITIONAL_ATTRIBUTE.substitute(name=arg['name'], i=i)) + local['trace_inputs'] = sum([['"{}"'.format(arg['name']), arg['name']] for arg in declaration['arguments']], []) # Record inplace operations as out-of-place operations (e.g., # not add_ but add) diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index d16c6d303f340b..2f1adf0ab59f4b 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -38,67 +38,6 @@ using namespace at; using namespace torch::autograd::generated; namespace torch { namespace autograd { -// Helper methods for working with Attributes (torch/csrc/jit/attributes.h) - -at::Tensor maybeUnwrapVar(const at::Tensor& t) { - return t.is_variable() ? Variable(t).data() : t; -} - -// The overloaded accessors are convenient for the generated code (since we -// don't want to make the codegen do the dispatch manually) -static void setattr(jit::Node* n, jit::Symbol name, int64_t v) { n->i_(name, v); } -static void setattr(jit::Node* n, jit::Symbol name, const at::Scalar& v) { n->t_(name, maybeUnwrapVar(v.toTensor())); } -static void setattr(jit::Node* n, jit::Symbol name, SparseTensorRef s) { n->t_(name, s.tref); } -static void setattr(jit::Node* n, jit::Symbol name, const at::IntList& v) { n->is_(name, v); } -static void setattr(jit::Node* n, jit::Symbol name, bool v) { n->i_(name, v); } -static void setattr(jit::Node* n, jit::Symbol name, double v) { n->f_(name, v); } -static void setattr(jit::Node* n, jit::Symbol name, std::string v) { n->s_(name, v); } -template -static void setattr(jit::Node* n, jit::Symbol name, std::array v) { n->is_(name, std::vector(v.begin(), v.end())); } - -static jit::Value* insertConstant(jit::Node* n, jit::IValue value) { - jit::WithInsertPoint guard(n); - return insertConstant(*n->owningGraph(), std::move(value)); -} - -static void genericInsertInput(jit::Node* n, size_t idx, jit::IValue value) { - n->insertInput(idx, insertConstant(n, std::move(value))); -} - -void failPositionalAttr() { - throw std::runtime_error("unsupported type in setposattr. File a bug report!"); -} - -static void setposattr(jit::Node* n, size_t idx, const char *name, int64_t v) { genericInsertInput(n, idx, v); } -static void setposattr(jit::Node* n, size_t idx, const char *name, const at::Scalar& v) { genericInsertInput(n, idx, v); } -static void setposattr(jit::Node* n, size_t idx, const char *name, SparseTensorRef s) { failPositionalAttr(); } -static void setposattr(jit::Node* n, size_t idx, const char *name, const at::IntList& v) { - using ArgumentStash = jit::tracer::ArgumentStash; - if (ArgumentStash::hasIntList(name)) { - auto info = ArgumentStash::popIntList(name); - for (size_t i = 0; i < info.size(); ++i) { - if (info[i] != nullptr) continue; - info[i] = insertConstant(n, v[i]); - } - for (jit::Value* v : info) { - if (*v->type() != *jit::IntType::get()) { - throw std::runtime_error( - "Type mismatch in setposattr for IntList. Check that your program " - "is valid without tracing, and please file a bug report if it is."); - } - } - jit::WithInsertPoint insert_point{n}; - auto& g = *n->owningGraph(); - auto size = g.insertNode(g.createList(jit::IntType::get(), info))->output(); - n->insertInput(idx, size); - } else { - return genericInsertInput(n, idx, v); - } -} -static void setposattr(jit::Node* n, size_t idx, const char *name, bool v) { genericInsertInput(n, idx, v); } -static void setposattr(jit::Node* n, size_t idx, const char *name, double v) { genericInsertInput(n, idx, v); } -template -static void setposattr(jit::Node* n, size_t idx, const char *name, std::array v) { failPositionalAttr(); } VariableType::VariableType(Context* context, Type* baseType) : Type(context, /*is_variable=*/true, /*is_undefined=*/false) diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py index 18c043a6c1061d..ad9ad2e05c4f4c 100644 --- a/tools/jit/gen_jit_dispatch.py +++ b/tools/jit/gen_jit_dispatch.py @@ -84,7 +84,7 @@ def from_attribute(arg): 'Scalar': '{}.toScalar()', 'ScalarType': 'static_cast({}.toInt())', 'Tensor': '{}.toTensor()', - 'bool': '{}.toInt()', + 'bool': 'bool({}.toInt())', 'double': '{}.toDouble()', 'int64_t': '{}.toInt()', 'std::array': 'as_bool_array<2>({}.toIntList()->elements())', diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index f3e52c0171b121..72d51bc2f304b9 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -1,7 +1,9 @@ #include "torch/csrc/jit/autodiff.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" +#include "torch/csrc/jit/passes/common_subexpression_elimination.h" #include "torch/csrc/jit/symbolic_variable.h" +#include "torch/csrc/jit/operator.h" #include "torch/csrc/utils/functional.h" #include @@ -13,36 +15,66 @@ namespace torch { namespace jit { using value_map = std::unordered_map; using value_set = std::unordered_set; -bool hasOneValuedInput(Node *n, torch::jit::Symbol name) { - auto maybe_t = n->get(name); - if (!maybe_t) return false; - return maybe_t->toDouble() == 1.0; +void wrapDim(int64_t & dim, const std::vector & sizes) { + if (dim < 0) { + dim += sizes.size(); + } } bool isDifferentiable(Node * n) { - static std::unordered_set differentiable_kinds = { - aten::add, aten::sub, aten::mul, prim::Constant, - aten::sigmoid, aten::tanh, aten::mm, aten::chunk, aten::split, aten::t, aten::neg, - aten::unsqueeze, aten::expand, aten::addmm, aten::gt, aten::lt, aten::eq, aten::ne, aten::ge, aten::le, aten::type_as, - aten::relu, aten::exp, prim::AutogradAdd + static OperatorSet differentiable_ops = { + "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", + "aten::add(Tensor self, Scalar other, *, Scalar alpha) -> Tensor", + "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", + "aten::sub(Tensor self, Scalar other, *, Scalar alpha) -> Tensor", + "aten::mul(Tensor self, Tensor other) -> Tensor", + "aten::mul(Tensor self, Scalar other) -> Tensor", + "aten::sigmoid(Tensor self) -> Tensor", + "aten::tanh(Tensor self) -> Tensor", + "aten::relu(Tensor self) -> Tensor", + "aten::exp(Tensor self) -> Tensor", + "aten::t(Tensor self) -> Tensor", + "aten::neg(Tensor self) -> Tensor", + "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]", + "aten::split(Tensor self, int split_size, int dim) -> Tensor[]", + "aten::type_as(Tensor self, Tensor other) -> Tensor", + "aten::unsqueeze(Tensor self, int dim) -> Tensor", + "aten::mm(Tensor self, Tensor mat2) -> Tensor", + "aten::lt(Tensor self, Tensor other) -> Tensor", + "aten::le(Tensor self, Tensor other) -> Tensor", + "aten::gt(Tensor self, Tensor other) -> Tensor", + "aten::ge(Tensor self, Tensor other) -> Tensor", + "aten::eq(Tensor self, Tensor other) -> Tensor", + "aten::ne(Tensor self, Tensor other) -> Tensor" }; - // TODO: check this more generally via schema - // This check ensures that the `alpha` and `beta` attributes on this addmm - // node are constant and equivalent to 1.0 - if (n->kind() == aten::addmm) { - if (n->inputs().size() > 3) - return false; - if (!hasOneValuedInput(n, attr::alpha) || !hasOneValuedInput(n, attr::beta)) - return false; - } - auto isTensor = [](Value* v) { return v->type()->isSubtypeOf(DynamicType::get()); }; - if(!std::all_of(n->inputs().begin(), n->inputs().end(), isTensor) - || !std::all_of(n->outputs().begin(), n->outputs().end(), isTensor)) - return false; + if (n->kind() == prim::Constant || n->kind() == prim::AutogradAdd) + return true; + if (differentiable_ops.find(n)) + return true; - if (n->kind() == aten::type_as && !n->inputs().at(1)->isTensor()) { - return false; + if (n->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) { + return static_cast(n->input(1)->type()->cast()); + } + if (n->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor")) { + if (!n->is_constant(attr::dim)) return false; + for (Value * input : n->inputs().slice(0, n->inputs().size() - 1)) { + if (!input->type()->cast()) return false; + } + return true; + } + if (n->matches("aten::squeeze(Tensor self) -> Tensor")) { + return static_cast(n->input()->type()->cast()); + } + if (n->matches("aten::squeeze(Tensor self, int dim) -> Tensor")) { + return n->namedInput(attr::self)->type()->cast() && n->is_constant(attr::dim); + } + if (n->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor")) { + return n->is_constant(attr::size) && n->is_constant(attr::implicit); + } + if (n->matches("aten::view(Tensor self, int[] size) -> Tensor") || + n->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) { + return static_cast(n->namedInput(attr::self)->type()->cast()); } // linear blocks may appear as inputs to graph executors, but they are removed @@ -55,7 +87,7 @@ bool isDifferentiable(Node * n) { static_cast(isDifferentiable)); } - return differentiable_kinds.count(n->kind()) > 0; + return false; } @@ -83,146 +115,149 @@ bool outputRequiresGrad(Node* node, std::function requires_grad) { } } - - static std::vector gradientForNode(Node* node, ArrayRef grad_values) { const auto build_sym_grad = [node](const std::vector& grads) -> std::vector { auto inputs = fmap(node->inputs()); auto outputs = fmap(node->outputs()); - switch(node->kind()) { - case aten::add: - // TODO (apaszke): remove formulas for attributed nodes once they are removed - // o = self + alpha*other - if(inputs.size() == 1) { - return { grads.at(0) }; - } else if (node->hasAttribute(attr::alpha)) { - return {grads.at(0), grads.at(0) * at::Scalar(node->t(attr::alpha))}; - } else { - return {grads.at(0), nullptr, grads.at(0) * node->namedInput(attr::alpha)}; - } - case aten::sub: - // o = self - alpha*other - if(inputs.size() == 1) { - return {grads.at(0)}; - } else if (node->hasAttribute(attr::alpha)) { - return {grads.at(0), -grads.at(0) * at::Scalar(node->t(attr::alpha))}; - } else { - return {grads.at(0), nullptr, grads.at(0) * node->namedInput(attr::alpha)}; - } - case aten::mul: - // o = self * other - if(inputs.size() == 1) - return {grads.at(0) * at::Scalar(node->t(attr::other))}; - else - return {grads.at(0) * inputs.at(1), grads.at(0) * inputs.at(0)}; - case prim::Constant: - return {}; - case aten::sigmoid: - return {grads.at(0) * outputs.at(0) * (1 - outputs.at(0))}; - case aten::tanh: - return {grads.at(0) * (1 - outputs.at(0) * outputs.at(0))}; - case aten::relu: - return {grads.at(0) * (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))}; - case aten::exp: - return {grads.at(0) * (outputs.at(0))}; - case aten::chunk: - case aten::split: - return {SymbolicVariable::cat(grads, node->namedInput(attr::dim))}; - case aten::t: - return {grads.at(0).t()}; - case aten::neg: - return {-grads.at(0)}; - case aten::view: - // TODO: if sizes are not available statically, add an operator that reutrns them as a tuple - return {grads.at(0).view(inputs.at(0).sizes())}; - case aten::type_as: - return {grads.at(0).type_as(inputs.at(0))}; - case aten::unsqueeze: - return {grads.at(0).squeeze(node->namedInput(attr::dim))}; - case aten::mm: { - SymbolicVariable dmat1, dmat2; - if (auto type = inputs.at(0).value()->type()->cast()) { - auto sizes = type->sizes(), strides = type->strides(); - if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) { - dmat1 = inputs.at(1).mm(grads.at(0).t()).t(); - } else { - dmat1 = grads.at(0).mm(inputs.at(1).t()); - } + + if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") || + node->matches("aten::add(Tensor self, Scalar other, *, Scalar alpha) -> Tensor")) { + return {grads.at(0), grads.at(0) * node->namedInput(attr::alpha), nullptr}; + + } else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") || + node->matches("aten::sub(Tensor self, Scalar other, *, Scalar alpha) -> Tensor")) { + return {grads.at(0), -grads.at(0) * node->namedInput(attr::alpha), nullptr}; + + } else if (node->matches("aten::mul(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) { + return {grads.at(0) * inputs.at(1), grads.at(0) * inputs.at(0)}; + + } else if (node->matches("aten::sigmoid(Tensor self) -> Tensor")) { + return {grads.at(0) * outputs.at(0) * (1 - outputs.at(0))}; + + } else if (node->matches("aten::tanh(Tensor self) -> Tensor")) { + return {grads.at(0) * (1 - outputs.at(0) * outputs.at(0))}; + + } else if (node->matches("aten::relu(Tensor self) -> Tensor")) { + return {grads.at(0) * (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))}; + + } else if (node->matches("aten::exp(Tensor self) -> Tensor")) { + return {grads.at(0) * (outputs.at(0))}; + + } else if (node->matches("aten::t(Tensor self) -> Tensor")) { + return {grads.at(0).t()}; + + } else if (node->matches("aten::neg(Tensor self) -> Tensor")) { + return {-grads.at(0)}; + + } else if (node->matches("aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]") || + node->matches("aten::split(Tensor self, int split_size, int dim) -> Tensor[]")) { + return {SymbolicVariable::cat(grads, node->namedInput(attr::dim)), nullptr, nullptr}; + + } else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor") || + node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) { + // TODO: if sizes are not available statically, add an operator that reutrns them as a tuple + auto sizes = node->namedInput(attr::self)->type()->expect()->sizes(); + return {grads.at(0).reshape(sizes), nullptr}; + + } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) { + return {grads.at(0).type_as(inputs.at(0)), nullptr}; + + } else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor")) { + return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr}; + + } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) { + SymbolicVariable dmat1, dmat2; + if (auto type = inputs.at(0).value()->type()->cast()) { + auto sizes = type->sizes(), strides = type->strides(); + if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) { + dmat1 = inputs.at(1).mm(grads.at(0).t()).t(); } else { dmat1 = grads.at(0).mm(inputs.at(1).t()); } - if (auto type = inputs.at(1).value()->type()->cast()) { - auto sizes = type->sizes(), strides = type->strides(); - if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) { - dmat2 = grads.at(0).t().mm(inputs.at(0)).t(); - } else { - dmat2 = inputs.at(0).t().mm(grads.at(0)); - } + } else { + dmat1 = grads.at(0).mm(inputs.at(1).t()); + } + if (auto type = inputs.at(1).value()->type()->cast()) { + auto sizes = type->sizes(), strides = type->strides(); + if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) { + dmat2 = grads.at(0).t().mm(inputs.at(0)).t(); } else { dmat2 = inputs.at(0).t().mm(grads.at(0)); } - return {dmat1, dmat2}; + } else { + dmat2 = inputs.at(0).t().mm(grads.at(0)); } - case aten::expand: { - const auto& input_sizes = inputs.at(0).sizes(); - if (input_sizes.size() == 0) - return {grads.at(0).sum()}; - auto grad_sizes = node->get>(attr::size).value(); - auto grad = grads.at(0); - while (grad_sizes.size() > input_sizes.size()) { - grad = grad.sum(0, false); - grad_sizes.erase(grad_sizes.begin()); - } - for (size_t i = 0; i < input_sizes.size(); ++i) { - if (input_sizes[i] == 1 && grad_sizes[i] > 1) { - grad = grad.sum(i, true); - } - } - return {grad}; + return {dmat1, dmat2}; + + } else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor")) { + const auto& input_sizes = inputs.at(0).sizes(); + if (input_sizes.size() == 0) + return {grads.at(0).sum(), nullptr, nullptr}; + auto grad_sizes = node->get>(attr::size).value(); + auto grad = grads.at(0); + while (grad_sizes.size() > input_sizes.size()) { + grad = grad.sum(0, false); + grad_sizes.erase(grad_sizes.begin()); } - case aten::squeeze: { - const auto& sizes = inputs.at(0).sizes(); - // TODO (apaszke): need to select the right overload here - if (node->hasAttribute(attr::dim)) { - int dim = node->i(attr::dim); - return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim)}; - } else { - std::vector squeezed_dims; - for (size_t i = 0; i < sizes.size(); ++i) { - if (sizes[i] != 1) continue; - squeezed_dims.push_back(i); - } - SymbolicVariable returned_grad = grads.at(0); - for (auto it = squeezed_dims.rbegin(); it != squeezed_dims.rend(); ++it) - returned_grad = returned_grad.unsqueeze(*it); - return {returned_grad}; + for (size_t i = 0; i < input_sizes.size(); ++i) { + if (input_sizes[i] == 1 && grad_sizes[i] > 1) { + grad = grad.sum(i, true); } } - case aten::cat: { - int dim = node->get(attr::dim).value(); - const auto& first_sizes = inputs.at(0).sizes(); - const auto has_first_sizes = [&first_sizes](SymbolicVariable var) { - return var.sizes() == first_sizes; - }; - // TODO (apaszke): This will need an adjustment for the dim argument - // NB: this is a specialization for the common case where all inputs are - // of equal sizes. We can use a single split operation to handle that. - if (std::all_of(inputs.begin(), inputs.end(), has_first_sizes)) { - return grads.at(0).chunk(inputs.size(), dim); - } else { - size_t offset = 0; - auto grad = grads.at(0); - std::vector returned_grads; - for (auto input : inputs) { - returned_grads.push_back(grad.narrow(dim, offset, input.sizes()[dim])); - offset += input.sizes()[dim]; - } - return returned_grads; + return {grad, nullptr, nullptr}; + + } else if (node->matches("aten::squeeze(Tensor self) -> Tensor")) { + const auto& sizes = inputs.at(0).sizes(); + std::vector squeezed_dims; + for (size_t i = 0; i < sizes.size(); ++i) { + if (sizes[i] != 1) continue; + squeezed_dims.push_back(i); + } + SymbolicVariable returned_grad = grads.at(0); + for (auto it = squeezed_dims.begin(); it != squeezed_dims.end(); ++it) + returned_grad = returned_grad.unsqueeze(*it); + return {returned_grad}; + + } else if (node->matches("aten::squeeze(Tensor self, int dim) -> Tensor", /*const=*/attr::dim)) { + int64_t dim = *node->get(attr::dim); + const auto& sizes = inputs.at(0).sizes(); + wrapDim(dim, sizes); + if (sizes.size() == 0) { + return {grads.at(0), nullptr}; + } + return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim), nullptr}; + + } else if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor", /*const=*/attr::dim)) { + int dim = *node->get(attr::dim); + auto tensor_inputs = inputs; tensor_inputs.pop_back(); + const auto& first_sizes = tensor_inputs.at(0).sizes(); + const auto has_first_sizes = [&first_sizes](SymbolicVariable var) { + return var.sizes() == first_sizes; + }; + + // NB: this is a specialization for the common case where all inputs are + // of equal sizes. We can use a single split operation to handle that. + if (std::all_of(tensor_inputs.begin(), tensor_inputs.end(), has_first_sizes)) { + auto tensor_grads = grads.at(0).chunk(tensor_inputs.size(), dim); + tensor_grads.push_back(nullptr); // for attr::dim + return tensor_grads; + } else { + size_t offset = 0; + auto grad = grads.at(0); + std::vector tensor_grads; + for (auto input : tensor_inputs) { + tensor_grads.push_back(grad.narrow(dim, offset, input.sizes()[dim])); + offset += input.sizes()[dim]; } + tensor_grads.push_back(nullptr); // for attr::dim + return tensor_grads; } + + } else if (node->kind() == prim::Constant) { + return {}; } - throw std::runtime_error(std::string("don't support differentiation of `") + - node->kind().toDisplayString() + "`"); + throw std::runtime_error(std::string("failed to differentiate `") + node->kind().toDisplayString() + "`"); }; if (!isDifferentiable(node)) { throw std::runtime_error(std::string("differentiation of ") + node->kind().toDisplayString() + " " @@ -273,15 +308,13 @@ static std::vector linearGradientForNode(Node* node, ArrayRef gr // to make reading gradient graphs easier, remember the name of the forward op linear->s_(attr::name, node->kind().toDisplayString()); auto block = linear->addBlock(); - { - WithInsertPoint guard(block); - auto results = gradientForNode(node, grad_values); - for(auto r : results) { - block->registerOutput(r); - linear->addOutput()->copyMetadata(r); - } - } - return linear->outputs(); + WithInsertPoint guard(block); + auto results = gradientForNode(node, grad_values); + return fmap(results, [block, linear](Value *grad) -> Value* { + if (!grad) return nullptr; + block->registerOutput(grad); + return linear->addOutput()->copyMetadata(grad); + }); } struct ReverseDetails { @@ -377,6 +410,40 @@ static ReverseDetails addReverseInline(Gradient& grad_desc, return ReverseDetails(std::move(grad_map), std::move(requires_grad_set), reverse_block); } +// Any temporary value from the primal graphs needs to be captured for later use in the +// reverse graph, to avoid costly recomputations. However, a lot of the nodes we have +// in our graphs are simply constants, which are cheap to execute and replicate, and so +// it's better to just copy them into the reverse graph, without polluting the output +// lists unnecessarily. +static void liftConstants(Gradient& grad_desc, ReverseDetails& rev_info) { + static const auto err = [](Value*) -> Value* { + throw std::runtime_error("unexpected input"); + }; + auto & graph = *grad_desc.f; + Block* reverse_block = rev_info.reverse_block; + + for (Node *top_node : reverse_block->nodes()) { + JIT_ASSERT(top_node->kind() == prim::GradOf || + top_node->kind() == prim::AutogradAdd || + top_node->kind() == prim::Undefined); + if (top_node->kind() != prim::GradOf) continue; + Block * grad_body = top_node->blocks().at(0); + for (Node *node : grad_body->nodes()) { + for (Value * input : node->inputs()) { + if (input->node()->kind() != prim::Constant) continue; + if (input->node()->owningBlock() == grad_body) continue; + Node *lifted_constant = graph.createClone(input->node(), err); + reverse_block->prependNode(lifted_constant); + node->replaceInputWith(input, lifted_constant->output()); + } + } + } + + // It's possible the we've cloned the same constants many times, + // so we use CSE to deduplicate them. + EliminateCommonSubexpression(reverse_block); +} + // Takes a grad_desc.f returned from `addReverseInline` and splits off the // reverse_block into its own graph, storing it in df. // All intermediates needed in the second stage are added to @@ -516,6 +583,8 @@ Gradient differentiate(std::shared_ptr& _graph, const std::vector& WithInsertPoint guard(grad_desc.f->block()); // Fills in df_input_vjps and df_output_vjps auto rev_info = addReverseInline(grad_desc, requires_grad); + // Lift constants captured for the reverse graph into it + liftConstants(grad_desc, rev_info); // addReverseInline has to call gradientForNode if *any* of the outputs // require grad, but it will emit vjps for *all* outputs. Use DCE to remove // unnecessary nodes. diff --git a/torch/csrc/jit/fusion_compiler.cpp b/torch/csrc/jit/fusion_compiler.cpp index 7e0db38e5b5614..8d20045efefe6a 100644 --- a/torch/csrc/jit/fusion_compiler.cpp +++ b/torch/csrc/jit/fusion_compiler.cpp @@ -4,6 +4,7 @@ #include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/code_template.h" #include "torch/csrc/jit/resource_guard.h" +#include "torch/csrc/jit/constants.h" #include "torch/csrc/utils/disallow_copy.h" #include "torch/csrc/variable_tensor_functions.h" @@ -196,15 +197,14 @@ static std::string valueName(Value * n) { return "n" + std::to_string(n->unique()); } -static std::string scalarValue(const at::Tensor & t) { - auto s = at::Scalar(t); - if (s.isIntegral()){ - return std::to_string(s.toLong()); - } else { - std::ostringstream out; - out << std::scientific << s.toDouble() << "f"; - return out.str(); - } +static std::string scalarValue(int64_t v) { + return std::to_string(v); +} + +static std::string scalarValue(double v) { + std::ostringstream out; + out << std::scientific << v << "f"; + return out.str(); } static const char * scalarTypeName(at::ScalarType type) { @@ -280,42 +280,31 @@ std::string encodeRHS(Node * n) { {aten::remainder, "remainderf(${0}, ${1})"}, {aten::pow, "powf(${0}, ${1})"}, - //alpha - {aten::add, "${0} + ${alpha}*${1}"}, - {aten::sub, "(${0} - ${alpha}*${1})"}, - - // special - {aten::lerp, "${0} + ${weight}*(${1} - ${0})"}, - {aten::clamp, "min(max(${0},${min}),${max})"}, + // binary with alpha + {aten::add, "${0} + ${2}*${1}"}, + {aten::sub, "(${0} - ${2}*${1})"}, // simple derivatives {aten::_sigmoid_backward, "${0} * ${1} * (1.f - ${1})"}, {aten::_tanh_backward, "${0} * (1.f - ${1} * ${1})"}, }; + if (n->kind() == prim::Constant) { + auto val = toIValue(n->output()).value(); + if (val.isDouble()) { + return scalarValue(val.toDouble()); + } else { + JIT_ASSERT(val.isInt()); + return scalarValue(val.toInt()); + } + } TemplateEnv env; size_t i = 0; for(auto in : n->inputs()) { env.s(std::to_string(i++), valueName(in)); } - // TODO (apaszke): remove once we get rid of attributes - // ops like div have a / b or a / 2 with the constant having the attribute other - // so we add other as an input if it is present - // 'pow' is the same but uses exponent as the attribute, so we handle that here as well - if(n->hasAttribute(attr::other) || n->hasAttribute(attr::exponent)) { - env.s(std::to_string(i), scalarValue(n->t(attr::other))); - } - // we also add any other scalar tensors to the env for special ops - for(auto a : n->attributeNames()) { - if(n->kindOf(a) == AttributeKind::t) { - auto v = n->t(a); - if(v.dim() == 0) { - JIT_ASSERT(a.is_attr()); - env.s(a.toUnqualString(), scalarValue(v)); - } - } - } + const auto & str = simple_map_ops.at(n->kind()); return format(str, env); } @@ -362,9 +351,12 @@ std::vector emitCompilationUnit(std::ostream & out, flat_output_nodes.push_back(o); } else { auto cat = o->node(); - size_t nInputs = cat->inputs().size(); + auto tensor_inputs = cat->inputs(); + // We need to drop the dim arg + tensor_inputs = tensor_inputs.slice(0, tensor_inputs.size() - 1); + size_t nInputs = tensor_inputs.size(); concat_desc.emplace_back(desc, nInputs, cat->get(attr::dim).value()); - for(auto c : cat->inputs()) { + for(auto c : tensor_inputs) { emitFormal(c, *concat_desc.back().subtensorDesc); flat_output_nodes.push_back(c); } diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index 2c595ffd679c27..df81c378ad137d 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -388,8 +388,10 @@ struct GraphExecutorImpl { auto graph_ = graph->copy(); runRequiredPasses(graph_); if(optimize) { - if(!symbolically_differentiable) + if(!symbolically_differentiable) { + EraseShapeInformation(*graph_); CreateAutodiffSubgraphs(*graph_); + } runOptimization(graph_, /*graphMustSupportVariables=*/true); } autograd_fallback_graph = graph_; diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index 2d3af265a5d651..d54b0434e7e64a 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -512,8 +512,46 @@ std::shared_ptr buildGraph(const Graph_& graph_, std::vector& return graph; } +// TODO: this should be removed once we'll be able to serialize value types +void reconstructOutputTypes(Block *b) { + for (Node * n : b->nodes()) { + if (n->kind() == prim::Constant) { + switch (n->kindOf(attr::value)) { + case AttributeKind::i: + n->output()->setType(IntType::get()); + break; + case AttributeKind::f: + n->output()->setType(FloatType::get()); + break; + case AttributeKind::is: + n->output()->setType(ListType::ofInts()); + break; + case AttributeKind::t: + n->output()->setType(DynamicType::get()); + break; + default: + throw std::runtime_error("Unsupported case in reconstructOutputTypes. File a bug report"); + } + } else if (n->kind() == prim::ListConstruct && n->inputs().size() > 0) { + auto input_types = fmap(n->inputs(), [](Value *v) -> TypePtr { + return v->node()->kind() == prim::Constant ? v->type() : nullptr; + }); + // Check that all types are equal + if (std::equal(std::next(input_types.begin()), input_types.end(), input_types.begin())) { + auto elem_type = input_types[0]; + if (elem_type == IntType::get()) { + n->output()->setType(ListType::ofInts()); + } + } + } + for (Block * b : n->blocks()) { + reconstructOutputTypes(b); + } + } } +} // anonymous namespace + std::shared_ptr ImportIRGraph(const std::string& serialized_graph, std::vector& initializers) { @@ -523,6 +561,8 @@ std::shared_ptr ImportIRGraph(const std::string& serialized_graph, auto graph = buildGraph(model.graph, initializers); + reconstructOutputTypes(graph->block()); + return graph; } diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index cf7dda32413c23..65bdcf695f6de2 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -147,6 +147,7 @@ static std::vector> flattenStages(Graph & graph) { while(input_pos < graph.inputs().size() && graph.inputs()[input_pos]->stage() == i) { auto nv = store->addOutput(); auto old_node = graph.inputs()[input_pos]; + nv->setType(old_node->type()); stage_input_types[i].push_back(old_node->type()); old_node->replaceAllUsesWith(nv); input_pos++; diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 6edf2bc176e364..7f09b22b324d11 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -619,7 +619,23 @@ Value* Node::namedInput(Symbol name) const { // so this is completely unsafe and needs to be gone as soon as possible. return v; } - return input(findArgument(schema(), name).first); + const auto & the_schema = schema(); + int64_t tensor_list_pos = 0; + for (auto & arg : the_schema.arguments) { + if (*arg.type == *ListType::ofTensors()) + break; + tensor_list_pos++; + } + int64_t arg_pos = findArgument(schema(), name).first; + // XXX: we don't have a single value we could give for a Tensor[], + // because we flatten lists into arguments + JIT_ASSERT(arg_pos != tensor_list_pos); + // NB: if there's no tensor list, then tensor_list_pos == arguments.size(), so this is always true + if (arg_pos < tensor_list_pos) { + return input(arg_pos); + } else { + return input(inputs().size() - (the_schema.arguments.size() - arg_pos)); + } } bool Node::matches(const char *signature_literal, at::ArrayRef const_inputs) { diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index 560239948325a3..19da2195e5b33d 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -210,8 +210,8 @@ struct SchemaParser { Lexer L; bool kwarg_only; }; -} +} // namespace script namespace { @@ -271,7 +271,7 @@ struct OperatorRegistry { operators_by_sig[canonicalSchemaString(op.schema)] = op_ptr; } - Operator& lookupByLiteral(const char * name) { + const std::shared_ptr& lookupByLiteral(const char * name) { auto it = operators_by_sig_literal.find(name); if (it == operators_by_sig_literal.end()) { auto op_ptr_it = operators_by_sig.find(name); @@ -286,7 +286,7 @@ struct OperatorRegistry { JIT_ASSERTM(op_ptr_it != operators_by_sig.end(), "Couldn't find an operator for %s", name); it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second); } - return *it->second; + return it->second; } const std::vector>& getOperators(Symbol name) { @@ -315,7 +315,7 @@ const std::vector>& getAllOperatorsFor(Symbol name) { } Operator& sig(const char *signature) { - return getRegistry().lookupByLiteral(signature); + return *getRegistry().lookupByLiteral(signature); } FunctionSchema parseSchema(const std::string& schema) { @@ -431,4 +431,26 @@ const Operator& getOperatorFor(const Node* node) { throw er; } + +OperatorSet::OperatorSet(std::initializer_list sig_literals) { + auto & registry = getRegistry(); + for (const char * sig : sig_literals) { + auto op = registry.lookupByLiteral(sig); + ops[Symbol::fromQualString(op->schema.name)].push_back(op); + } +} + +Operator* OperatorSet::find(Node *n) { + auto it = ops.find(n->kind()); + if (it == ops.end()) { + return nullptr; + } + for (auto & op : it->second) { + if (op->matches(n)) { + return op.get(); + } + } + return nullptr; +} + }} diff --git a/torch/csrc/jit/operator.h b/torch/csrc/jit/operator.h index 47ed788770f1cb..7e6a314d2cb8c3 100644 --- a/torch/csrc/jit/operator.h +++ b/torch/csrc/jit/operator.h @@ -38,6 +38,7 @@ struct TORCH_API Operator { // as attributes or inputs. This function returns the right Operation function, // given a node encoded for one variant. // Behavior is undefined if matches(n) == false + // TODO (apaszke) : remove Operation selectVariant(Node* n) const { if(n->hasAttributes()) { JIT_ASSERT(op_const_attributes != nullptr); @@ -77,4 +78,13 @@ struct TORCH_API RegisterOperators { } }; +struct OperatorSet { + OperatorSet(std::initializer_list sig_literals); + // XXX: Returns a nullptr if no Operator in the set matches n + Operator* find(Node *n); +private: + std::unordered_map>> ops; +}; + + }} diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.h b/torch/csrc/jit/passes/common_subexpression_elimination.h index 64ae4f6bd9ca8b..f74f0868eb7a88 100644 --- a/torch/csrc/jit/passes/common_subexpression_elimination.h +++ b/torch/csrc/jit/passes/common_subexpression_elimination.h @@ -5,5 +5,6 @@ namespace torch { namespace jit { TORCH_API void EliminateCommonSubexpression(std::shared_ptr& graph); +TORCH_API void EliminateCommonSubexpression(Block * block); }} diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index 745e910ccc9f9a..e0395ffdaadeae 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -32,7 +32,6 @@ std::unordered_set simple_mappable = { aten::atan, aten::atan2, aten::ceil, - aten::clamp, aten::cos, aten::cosh, aten::div, @@ -45,7 +44,6 @@ std::unordered_set simple_mappable = { aten::ge, aten::gt, aten::le, - aten::lerp, aten::lgamma, aten::log, aten::log10, @@ -74,33 +72,17 @@ std::unordered_set simple_mappable = { aten::type_as, aten::_sigmoid_backward, aten::_tanh_backward, + // TODO support those + //aten::clamp, + //aten::lerp, }; bool isSimpleMap(Node *node) { + // TODO: use signature matching if(simple_mappable.count(node->kind()) == 0) return false; if((node->kind() == aten::min || node->kind() == aten::max) && node->inputs().size() == 1) return false; - // Make sure that the node doesn't broadcast. - JIT_ASSERT(node->inputs().size() > 0); - TensorTypePtr expected_type = node->inputs()[0]->type()->cast(); - if (!expected_type) return false; -//type checking is intentionally dropped from isSimpleMap -//isFusable is checking input/output types as there are some exceptions from allFloatIO requirement - static const auto equal_modulo_strides = [](const TensorTypePtr& expected, const TypePtr& _actual) { - TensorTypePtr actual = _actual->cast(); - return actual && - expected->device() == actual->device() && - expected->sizes() == actual->sizes(); - }; - for (Value * val : node->inputs()) { - if (!equal_modulo_strides(expected_type, val->type())) - return false; - } - for (Value * val : node->outputs()) { - if (!equal_modulo_strides(expected_type, val->type())) - return false; - } return true; } @@ -133,10 +115,8 @@ struct GraphFuser { bool hasSupportedType(Value* node) { if (auto tt = node->type()->cast()) { if (tt->scalarType() == at::kFloat) return true; - #ifdef USE_CUDA // Checks for half tensor on GPU - // const auto device = tt->device(); if (tt->device() != kCPUDevice && CUDA_VERSION >= 9 && tt->scalarType() == at::ScalarType::Half) { @@ -144,57 +124,74 @@ struct GraphFuser { } #endif } - return false; } - bool allSupportedList(at::ArrayRef list){ - for (auto& o: list){ - if (!hasSupportedType(o)) return false; + bool areTensorsOfSameShape(at::ArrayRef values) { + auto expected_type = values.at(0)->type()->cast(); + if (!expected_type) return false; + for (Value * val : values) { + auto val_type = val->type()->cast(); + if (!val_type) return false; + if (expected_type->device() != val_type->device()) return false; + if (expected_type->sizes() != val_type->sizes()) return false; } - return true; } - bool allSupportedIO(Node* node) { - return (allSupportedList(node->inputs()) && allSupportedList(node->outputs())); + bool hasSupportedType(Node* node) { + return areTensorsOfSameShape(node->inputs()) && + haveSupportedType(node->inputs()) && + haveSupportedType(node->outputs()); + } + + bool haveSupportedType(at::ArrayRef list) { + for (Value *v : list) { + if (!hasSupportedType(v)) return false; + } + return true; } + bool isFusable(Node * node) { if (node->owningBlock() != block) return false; if (node->kind() == prim::FusionGroup) return true; if (!isSimpleMap(node)) return false; - switch (node->kind()){ -//comparison operators produce Byte type, and it's ok, check only inputs - case aten::le: - case aten::ge: - case aten::lt: - case aten::gt: - case aten::ne: - case aten::eq: - return allSupportedList(node->inputs()); - case aten::type_as: -//type_as can have different input types as long as output is float, check only output - return allSupportedList(node->outputs()); - default: - return allSupportedIO(node); + + if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", /*const=*/attr::alpha)) { + std::vector inputs {node->namedInput(attr::self), node->namedInput(attr::other)}; + return areTensorsOfSameShape(inputs) && haveSupportedType(inputs); + } else if (node->matches("aten::add(Tensor self, Scalar other, *, Scalar alpha) -> Tensor", /*const=*/{attr::other, attr::alpha})) { + return hasSupportedType(node->namedInput(attr::self)); + } else if (node->matches("aten::lt(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::le(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::eq(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::ne(Tensor self, Tensor other) -> Tensor")) { + // comparison operators produce Byte type, and it's ok, check only inputs + return areTensorsOfSameShape(node->inputs()) && haveSupportedType(node->inputs()); + } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) { + // type_as can have different input types as long as output is float, check only output + return haveSupportedType(node->outputs()); + } else { + return hasSupportedType(node); } } - bool allOutputsHaveSameSize(Node * node) { - TensorTypePtr tt_ptr = nullptr; - for (const auto i : node->inputs()) { - auto cur_tt_ptr = i->type()->cast(); - if (!cur_tt_ptr) { - return false; - } - - if (tt_ptr && tt_ptr->sizes() != cur_tt_ptr->sizes()) { - return false; - } - tt_ptr = cur_tt_ptr; + bool allCatInputsHaveSameSize(Node * node) { + JIT_ASSERT(node->kind() == aten::cat); + std::vector inputs = node->inputs(); + if (!node->hasAttributes()) { + inputs.pop_back(); // Get rid of the dim argument } - return true; + + auto expected = inputs.at(0)->type()->cast(); + if (!expected) return false; + return std::all_of(inputs.begin(), inputs.end(), [expected](Value *v) { + auto actual = v->type()->cast(); + return actual && actual->sizes() == expected->sizes(); + }); } // Can this node produce an _output_ of a fusion group? @@ -207,7 +204,7 @@ struct GraphFuser { // this concat fusion only works when all the inputs are the same size // and we can statically infer the dimension along which we should concat // otherwise they cannot partipate in the same map - if(node->kind() == aten::cat && node->get(attr::dim) && allOutputsHaveSameSize(node)) + if(node->kind() == aten::cat && node->is_constant(attr::dim) && allCatInputsHaveSameSize(node)) return true; return false; @@ -327,12 +324,24 @@ struct GraphFuser { inputs_map[input] = subgraph.inputs()[i++]; } // add n's inputs to the fusion group's input list if we don't already have them + Node * insert_after = nullptr; for (auto input : n->inputs()) { if (inputs_map.count(input) == 0) { - auto in_group = subgraph.addInput(); - in_group->setType(input->type()); - inputs_map[input] = in_group; - group->addInput(input); + if (input->type()->isSubtypeOf(DynamicType::get())) { + auto in_group = subgraph.addInput(); + in_group->setType(input->type()); + inputs_map[input] = in_group; + group->addInput(input); + } else { + // We don't support passing in scalars as arguments to fused kernels, so we generally + // don't allow fusing tensor-scalar operations unless the scalar is constant. In those + // cases we inline the constants directly in the body of the fused group. + JIT_ASSERT(input->node()->kind() == prim::Constant); + Node * in_const = subgraph.createClone(input->node(), [](Value*) -> Value* { throw std::runtime_error("unexpected input"); }); + subgraph.prependNode(in_const); + insert_after = in_const; + inputs_map[input] = in_const->output(); + } } } // copy n into the graph, remapping its inputs to internal nodes @@ -351,7 +360,7 @@ struct GraphFuser { subgraph.inputs()[p]->replaceAllUsesWith(in_graph->output()); subgraph.eraseInput(p); } - return subgraph.prependNode(in_graph); + return insert_after ? in_graph->insertAfter(insert_after) : subgraph.prependNode(in_graph); } // turn consumer node n into a fusion group with just n inside @@ -432,7 +441,7 @@ struct GraphFuser { if (!isChunk(chunk)) return false; // and the thing being chunked is fusable into the consumer - Value * producer_for_chunk = chunk->input(); + Value * producer_for_chunk = chunk->namedInput(attr::self); if (!isFusable(producer_for_chunk->node()) || !allUsersAreThisConsumer(chunk,producer_for_chunk)) return false; // and all uses of the chunk are in this consumer @@ -457,20 +466,25 @@ struct GraphFuser { std::vector> chunked_inputs; for (auto input : producer_for_chunk_node->inputs()) { auto input_type = input->type()->cast(); + // XXX: we only work with pointwise ops in here, so we know it is valid to push + // the concat only through tensor arguments (and all other args can be safely ignored). + if (!input_type) + continue; // NB: I decided not to use cloneFrom here, because if we make cloneFrom // copy selects one day, it is definitely not what you want here (selects // have different types). // TODO: Perhaps we should use cloneFrom now, as it seems unlikely // to copy select nodes now that we have refactored to have a Value // distinct from Node. - Node * input_chunk = block->owningGraph()->create(chunk->kind(), 0); - input_chunk->copyAttributes(*chunk); + Node * input_chunk = block->owningGraph()->create(aten::chunk, 0); input_chunk->addInput(input); + input_chunk->addInput(chunk->namedInput(attr::chunks)); + input_chunk->addInput(chunk->namedInput(attr::dim)); insertAt(&insertion_point, input_chunk); chunked_inputs.emplace_back(); // alas, to not be C++17 for (auto chunk_sel : chunk->outputs()) { - auto chunk_sel_type = chunk_sel->type()->cast(); + auto chunk_sel_type = chunk_sel->type()->expect(); Value * input_chunk_sel = input_chunk->addOutput(); input_chunk_sel->setType( input_type->withSizesStrides(chunk_sel_type->sizes(), @@ -482,12 +496,20 @@ struct GraphFuser { // apply the op to each chunk of the chunked operands, // and then rewrite the graph to use them! for (auto chunk_sel : chunk->outputs()) { + auto original_inputs = producer_for_chunk_node->inputs(); Node * chunked_op = block->owningGraph()->create(producer_for_chunk_node->kind()); chunked_op->copyAttributes(*producer_for_chunk_node); // Invariant: mappable operators always produce contiguous output chunked_op->output()->setType(chunk_sel->type()->cast()->contiguous()); - for (auto by_chunk_output_idx : chunked_inputs) { - chunked_op->addInput(by_chunk_output_idx.at(chunk_sel->offset())); + auto chunked_inputs_it = chunked_inputs.begin(); + for (size_t i = 0; i < original_inputs.size(); ++i) { + if (original_inputs[i]->type()->isSubtypeOf(DynamicType::get())) { + JIT_ASSERT(chunked_inputs_it != chunked_inputs.end()); + chunked_op->addInput(chunked_inputs_it->at(chunk_sel->offset())); + ++chunked_inputs_it; + } else { + chunked_op->addInput(original_inputs[i]); + } } insertAt(&insertion_point, chunked_op); chunk_sel->replaceAllUsesWith(chunked_op->output()); diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index 030b4dc7a34395..75fb063c761a31 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -1,5 +1,6 @@ #include "torch/csrc/utils/pybind.h" #include "torch/csrc/jit/passes/onnx.h" +#include "torch/csrc/jit/passes/dead_code_elimination.h" #include "torch/csrc/autograd/function.h" #include "torch/csrc/autograd/symbolic.h" #include "torch/csrc/jit/assertions.h" @@ -194,6 +195,7 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo // Copy stage from original graph ctx.block->owningGraph()->setStage(old_block->owningGraph()->stage()); + EliminateDeadCode(ctx.block); } }} diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index ea256f8e1f867f..7fcd47f3a23b54 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -96,13 +96,18 @@ void fuseBroadcast(Block *b) { JIT_ASSERT(!n->hasAttribute(attr::axis)); auto input_index = n->inputs().size() - 1; - auto* expanded_rhs = n->input(input_index)->node(); - - // The expanded_rhs input isn't actually an expand, so no fusion available - if (expanded_rhs->kind() != aten::expand) continue; - if (expanded_rhs->inputs().size() != 1) continue; + auto* rhs_expand = n->input(input_index)->node(); + + // The rhs_expand input isn't actually an expand, so no fusion available + // XXX: we can't use the ->matches(...) mechanism in here, because input nodes + // have been + if (rhs_expand->kind() != aten::expand || + rhs_expand->input(1)->node()->kind() != onnx::Constant || + rhs_expand->input(2)->node()->kind() != onnx::Constant) { + continue; + } - auto* unexpanded_rhs = expanded_rhs->input(); + auto* unexpanded_rhs = rhs_expand->input(0); // We need to know what the type pre-expand is. We should basically // always have this information (because expands are only ever traced, @@ -113,7 +118,7 @@ void fuseBroadcast(Block *b) { // Not all broadcasts are supported by ONNX broadcast. at::optional axis = fusibleExpandTo( unexpanded_rhs->type()->expect()->sizes(), // from - expanded_rhs->output()->type()->expect()->sizes()); // to + rhs_expand->output()->type()->expect()->sizes()); // to if (axis == at::nullopt) continue; @@ -128,8 +133,8 @@ void fuseBroadcast(Block *b) { n->i_(attr::axis, axis.value()); } } - if (!expanded_rhs->hasUses()) { - expanded_rhs->destroy(); + if (!rhs_expand->hasUses()) { + rhs_expand->destroy(); } } } diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index 2ee777aee0e66f..ac0c96232647e9 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -30,7 +30,7 @@ void PeepholeOptimize(Block * block) { if (auto input_type = node->namedInput(attr::self)->type()->cast()) { auto expanded_sizes = node->get>(attr::size); if (expanded_sizes == input_type->sizes()) { - node->output()->replaceAllUsesWith(node->input()); + node->output()->replaceAllUsesWith(node->namedInput(attr::self)); } } } else if (node->matches("aten::t(Tensor self) -> Tensor")) { diff --git a/torch/csrc/jit/passes/remove_expands.cpp b/torch/csrc/jit/passes/remove_expands.cpp index f0f591cac59ec9..93d53e54819bbd 100644 --- a/torch/csrc/jit/passes/remove_expands.cpp +++ b/torch/csrc/jit/passes/remove_expands.cpp @@ -9,7 +9,7 @@ static void RemoveExpands(Block* block) { for (auto sub : it->blocks()) RemoveExpands(sub); if (it->kind() == aten::expand && it->get(attr::implicit) != static_cast(0)) { - it->output()->replaceAllUsesWith(it->input()); + it->output()->replaceAllUsesWith(it->namedInput(attr::self)); it.destroyCurrent(); } } diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index e6136b03c4414e..f1fef4c5247ea0 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -124,9 +124,11 @@ void broadcastBinary(Node *node, std::vector& types, size_t idx1, if (input_type->sizes() == expected_size) return; auto graph = node->owningGraph(); - Node *expand = graph->create(aten::expand, {node->inputs().at(input_idx)}) - ->is_(attr::size, expected_size) - ->i_(attr::implicit, 0) + WithInsertPoint point_guard { node }; + Node *expand = graph->create(aten::expand, + {node->inputs().at(input_idx), + insertConstant(*graph, expected_size), + insertConstant(*graph, 0)}) ->insertBefore(node); PropagateShapeOnNode(expand); node->replaceInput(input_idx, expand->output()); @@ -280,6 +282,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { node->matches("aten::min(Tensor self, Tensor other) -> Tensor") || node->matches("aten::max(Tensor self, Tensor other) -> Tensor") || node->matches("aten::lt(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::le(Tensor self, Tensor other) -> Tensor") || node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") || node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") || node->matches("aten::eq(Tensor self, Tensor other) -> Tensor") || @@ -451,7 +454,8 @@ void PropagateShapeOnBlock(Block * block, bool insert_expands) { } } -} +} // anonymous namespace + void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec) { JIT_ASSERT(graph.inputs().size() == spec.size()); for(size_t i = 0; i < spec.size(); ++i) { @@ -462,4 +466,29 @@ void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec) { PropagateShapeOnBlock(graph.block()); } +namespace { + +void EraseShapeInformation(at::ArrayRef vals) { + for (Value * v : vals) { + v->setType(unshapedType(v->type())); + } +} + +void EraseShapeInformation(Block * b) { + EraseShapeInformation(b->inputs()); + EraseShapeInformation(b->outputs()); + for (Node * n : b->nodes()) { + EraseShapeInformation(n->outputs()); + for (Block *sb : n->blocks()) { + EraseShapeInformation(sb); + } + } +} + +} // anonymous namespace + +void EraseShapeInformation(Graph & graph) { + EraseShapeInformation(graph.block()); +} + }} diff --git a/torch/csrc/jit/passes/shape_analysis.h b/torch/csrc/jit/passes/shape_analysis.h index 1b38cbbe5739a4..199d376e87ddec 100644 --- a/torch/csrc/jit/passes/shape_analysis.h +++ b/torch/csrc/jit/passes/shape_analysis.h @@ -3,8 +3,11 @@ #include "torch/csrc/WindowsTorchApiMacro.h" namespace torch { namespace jit { + struct Graph; struct ArgumentSpec; + +void EraseShapeInformation(Graph & graph); TORCH_API void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec); }} diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index 05dbe341143d79..81211085569953 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -281,7 +281,10 @@ void initPythonIRBindings(PyObject * module_) { #undef VS - py::class_>(m, "Block"); + py::class_>(m, "Block") + .def("nodes",[](Block &b) { + return py::make_iterator(b.nodes().begin(), b.nodes().end()); + }); #define NS(name) \ def(#name,&Node :: name) @@ -461,6 +464,8 @@ void initPythonIRBindings(PyObject * module_) { } return types; }); + py::class_>(m, "ListType") + .def_static("ofInts", &ListType::ofInts); py::class_(m,"Use") .def_readonly("user",&Use::user) diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp index 78247017b9e9c6..7439b2b5e334cc 100644 --- a/torch/csrc/jit/python_tracer.cpp +++ b/torch/csrc/jit/python_tracer.cpp @@ -77,12 +77,25 @@ PreTraceInfo preRecordPythonTrace(THPObjectPtr pyobj, if(!apply) { throw python_error(); } - return makePreTraceInfo(inputs, [&](const std::shared_ptr& state, Graph& graph) { - return graph.createPythonOp( - std::move(apply), - arg_types, - std::move(scalar_args)); - }); + + PreTraceInfo info; + auto & state = getTracingState(); + auto & graph = state->graph; + + Node *n = info.n = graph->createPythonOp( + std::move(apply), + arg_types, + std::move(scalar_args)); + recordSourceLocation(n); + + for (const Variable & input : inputs) { + n->addInput(getValueTrace(input)); + } + + // NB: Order matters. This must append after inputs but before outputs. + graph->appendNode(n); + + return info; } void pythonRecordSourceLocation(Node* n) { diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index aa39746d779ec3..6cf7b37d4f43c0 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -386,72 +386,6 @@ at::optional> getIntListAttribute(at::optional N, return std::vector(*N, *r); } -// try to turn constant inputs into attributes -void liftConstantAttributes(const FunctionSchema& schema, Node* node) { - // we shouldn't start with attributes, just inputs - JIT_ASSERT(!node->hasAttributes()); - std::vector new_inputs; - Attributes attributes; - for(size_t i = 0, n = 0; i < schema.arguments.size(); ++i) { - const auto& arg = schema.arguments[i]; - // this was a builtin with a vararg list lowered, - if(*arg.type == *ListType::ofTensors()) { - // we need to skip all the vararg nodes, and continue parsing the - // possible attribute nodes - size_t vararg_list_size = node->inputs().size() - (schema.arguments.size() - 1); - while(n < i + vararg_list_size) { - new_inputs.push_back(node->input(n++)); - } - continue; - } - auto input = node->input(n++); - switch(arg.type->kind()) { - case TypeKind::IntType:{ - auto r = constant_as(input); - if(!r) - return; - attributes.i_(Symbol::attr(arg.name), *r); - } break; - case TypeKind::FloatType: { - auto r = constant_as(input); - if(!r) - return; - attributes.f_(Symbol::attr(arg.name), *r); - } break; - case TypeKind::NumberType: { - auto r = constant_as(input); - if(!r) - return; - attributes.t_(Symbol::attr(arg.name), r->toTensor()); - } break; - case TypeKind::ListType: { - auto elem = arg.type->expect()->getElementType(); - if(elem->kind() == TypeKind::IntType) { - auto r = getIntListAttribute(arg.N, input); - if(!r) - return; - attributes.is_(Symbol::attr(arg.name), *r); - } else { - // only IntLists can become attributes, other - // types are not attribute-able - new_inputs.push_back(input); - } - } break; - default: - new_inputs.push_back(input); - } - } - // nothing changed no need to modify the node - if(!attributes.hasAttributes()) - return; - - node->removeAllInputs(); - for(Value* input : new_inputs) { - node->addInput(input); - } - node->copyAttributes(attributes); -} - at::ArrayRef createTupleUnpack(Value* v) { // small peephole optimization to ensure IntList attributes can still turn // into constants e.g. in x.expand([3, 4]) @@ -516,10 +450,8 @@ at::optional> tryMatchSchema( return at::nullopt; } positional_inputs[i] = NamedValue( - loc, - i, - insertConstant(graph, *default_value, loc) - ->setType(schema.arguments[i].type)); + loc, i, + insertConstant(graph, *default_value, loc)); } // check input types @@ -579,10 +511,6 @@ static std::shared_ptr tryEmitBuiltin( return nullptr; // we successfully matched this schema, construct the node - // note: we always construct purely positional nodes here - // the pass liftConstantAttributes replaces the node with with one that - // uses attributes if all the attributes ended up as constants - NodeKind kind(Symbol::aten(name)); auto n = graph->insertNode(graph->create(kind, *flat_inputs, 0)) ->setSourceLocation(std::make_shared(loc)); @@ -603,9 +531,6 @@ static std::shared_ptr tryEmitBuiltin( } } - if(op->hasAttributedVersion()) - liftConstantAttributes(op->schema, n); - // assert that we did indeed create an op that has implementation // otherwise schema and dispatch are not in sync getOperation(n); diff --git a/torch/csrc/jit/symbolic_variable.h b/torch/csrc/jit/symbolic_variable.h index ff9e5149068a81..e4d2f98ba0ea0f 100644 --- a/torch/csrc/jit/symbolic_variable.h +++ b/torch/csrc/jit/symbolic_variable.h @@ -1,6 +1,7 @@ #pragma once #include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/constants.h" namespace torch { namespace jit { @@ -56,90 +57,51 @@ struct SymbolicVariable { return create(aten::mul, {*this, rhs})[0].typeLike(*this); } SymbolicVariable operator*(at::Scalar rhs) const { - if(isConstInt(rhs, 1)) + if (isConstInt(rhs, 1)) return *this; - Node * n; - auto r = create(aten::mul, {*this}, 1, &n)[0]; - n->t_(attr::other, rhs.toTensor()); - return r; + return (*this) * insertConstant(rhs); } SymbolicVariable operator>(at::Scalar rhs) const { - Node * n; - auto r = create(aten::gt, {*this}, 1, &n)[0].typeLikeWithScalarType(*this, at::kByte); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::gt, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte); } SymbolicVariable operator<(at::Scalar rhs) const { - Node * n; - auto r = create(aten::lt, {*this}, 1, &n)[0].typeLikeWithScalarType(*this, at::kByte); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::lt, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte); } SymbolicVariable operator>=(at::Scalar rhs) const { - Node * n; - auto r = create(aten::ge, {*this}, 1, &n)[0].typeLikeWithScalarType(*this, at::kByte); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::ge, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte); } SymbolicVariable operator<=(at::Scalar rhs) const { - Node * n; - auto r = create(aten::le, {*this}, 1, &n)[0].typeLikeWithScalarType(*this, at::kByte); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::le, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte); } SymbolicVariable operator==(at::Scalar rhs) const { - Node * n; - auto r = create(aten::eq, {*this}, 1, &n)[0].typeLikeWithScalarType(*this, at::kByte); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::eq, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte); } SymbolicVariable operator!=(at::Scalar rhs) const { - Node * n; - auto r = create(aten::ne, {*this}, 1, &n)[0].typeLikeWithScalarType(*this, at::kByte); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::ne, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte); } SymbolicVariable operator+(const SymbolicVariable rhs) const { - Node * n; - auto r = create(aten::add, {*this, rhs}, 1, &n)[0].typeLike(*this); - n->t_(attr::alpha, at::Scalar(1).toTensor()); - return r; + return create(aten::add, {*this, rhs, insertConstant(1)})[0].typeLike(*this); } SymbolicVariable operator+(at::Scalar rhs) const { - Node * n; - auto r = create(aten::add, {*this}, 1, &n)[0].typeLike(*this); - n->t_(attr::alpha, at::Scalar(1).toTensor()); - n->t_(attr::other, rhs.toTensor()); - return r; + return (*this) + insertConstant(rhs); } SymbolicVariable operator-() const { return create(aten::neg, {*this})[0].typeLike(*this); } SymbolicVariable operator-(const SymbolicVariable rhs) const { - Node *n; - auto r = create(aten::sub, {*this, rhs}, 1, &n)[0].typeLike(*this); - n->t_(attr::alpha, at::Scalar(1).toTensor()); - return r; + return create(aten::sub, {*this, rhs, insertConstant(1)})[0].typeLike(*this); } SymbolicVariable operator/(at::Scalar rhs) const { - Node *n; - auto r = create(aten::div, {*this}, 1, &n)[0].typeLike(*this); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::div, {*this, insertConstant(rhs)})[0].typeLike(*this); } SymbolicVariable operator%(at::Scalar rhs) const { - Node *n; - auto r = create(aten::remainder, {*this}, 1, &n)[0].typeLike(*this); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::remainder, {*this, insertConstant(rhs)})[0].typeLike(*this); } SymbolicVariable mm(const SymbolicVariable rhs) const { - auto r = create(t("mm"), {*this, rhs})[0]; - return r; + return create(t("mm"), {*this, rhs})[0]; } SymbolicVariable t() const { - auto r = create(t("t"), {*this})[0]; - return r; + return create(t("t"), {*this})[0]; } SymbolicVariable sigmoid() const { return create(aten::sigmoid, {*this})[0].typeLike(*this); @@ -147,88 +109,73 @@ struct SymbolicVariable { SymbolicVariable tanh() const { return create(aten::tanh, {*this})[0].typeLike(*this); } - std::vector chunk(int32_t chunks, uint32_t dim) const { - Node * n; - auto r = create(t("chunk"), { *this }, chunks, &n); - n->i_(a("chunks"), chunks) - ->i_(a("dim"), dim); - return r; + std::vector chunk(int64_t chunks, int dim) const { + return create(t("chunk"), { *this , insertConstant(chunks), insertConstant(dim) }, chunks); } SymbolicVariable type_as(const SymbolicVariable rhs) const { return create(aten::type_as, {*this, rhs})[0].typeLikeWithRhsScalarType(*this, rhs); } SymbolicVariable narrow(int dim, int64_t start, int64_t length) const { - Node * n; - auto r = create(t("narrow"), { *this }, 1, &n)[0]; - n->i_(a("dim"), dim) - ->i_(a("start"), start) - ->i_(a("length"), length); - return r; + return create(t("narrow"), { *this, insertConstant(dim), insertConstant(start), insertConstant(length) }, 1)[0]; } static SymbolicVariable cat(ArrayRef inputs, Value* dim) { - Node* n; std::vector all_inputs = inputs; all_inputs.push_back(dim); - auto r = create(aten::cat, all_inputs, 1, &n)[0]; - return r; + return create(aten::cat, all_inputs)[0]; } - static SymbolicVariable cat(ArrayRef inputs, int32_t dim) { - Node* n; - auto r = create(aten::cat, inputs, 1, &n)[0]; - n->i_(attr::dim, dim); - return r; + static SymbolicVariable cat(ArrayRef inputs, int dim) { + JIT_ASSERT(inputs.size() > 0); + return SymbolicVariable::cat(inputs, inputs[0].insertConstant(dim)); } - static SymbolicVariable stack(ArrayRef inputs, int32_t dim) { - Node* n; - auto r = create(aten::stack, inputs, 1, &n)[0]; - n->i_(attr::dim, dim); - return r; + static SymbolicVariable stack(ArrayRef inputs, Value* dim) { + std::vector all_inputs = inputs; + all_inputs.push_back(dim); + return create(aten::stack, all_inputs)[0]; + } + static SymbolicVariable stack(ArrayRef inputs, int dim) { + JIT_ASSERT(inputs.size() > 0); + return SymbolicVariable::stack(inputs, inputs[0].insertConstant(dim)); } SymbolicVariable sum() const { - auto r = create(t("sum"), {*this})[0]; - return r; + return create(t("sum"), {*this})[0]; } SymbolicVariable sum(int dim, bool keepdim) const { - Node * n; - auto r = create(t("sum"), {*this}, 1, &n)[0]; - n->is_(a("dim"), {dim}) - ->i_(a("keepdim"), keepdim); - return r; + return create(t("sum"), {*this, insertConstant(at::IntList{dim}), insertConstant(keepdim)})[0]; } SymbolicVariable squeeze(Value* dim) const { - Node * n; - auto r = create(t("squeeze"), {*this, dim}, 1, &n)[0]; - return r; + return create(t("squeeze"), {*this, dim})[0]; } SymbolicVariable squeeze(int dim) const { - Node * n; - auto r = create(t("squeeze"), {*this}, 1, &n)[0]; - n->i_(a("dim"), dim); - return r; + return squeeze(insertConstant(dim)); + } + SymbolicVariable unsqueeze(Value* dim) const { + return create(t("unsqueeze"), {*this, dim})[0]; } SymbolicVariable unsqueeze(int dim) const { - Node * n; - auto r = create(t("unsqueeze"), {*this}, 1, &n)[0]; - n->i_(a("dim"), dim); - return r; + return unsqueeze(insertConstant(dim)); + } + SymbolicVariable view(Value* sizes) const { + return create(aten::view, {*this, sizes})[0]; } SymbolicVariable view(std::vector sizes) const { - Node *n; - auto r = create(aten::view, {*this}, 1, &n)[0]; - n->is_(a("size"), std::move(sizes)); - return r; + return view(insertConstant(sizes)); + } + SymbolicVariable reshape(Value* sizes) const { + return create(aten::reshape, {*this, sizes})[0]; + } + SymbolicVariable reshape(std::vector sizes) const { + return reshape(insertConstant(sizes)); } SymbolicVariable addmm(SymbolicVariable mat1, SymbolicVariable mat2) const { - Node *n; - auto r = create(aten::addmm, {*this, mat1, mat2}, 1, &n)[0]; - n->t_(a("alpha"), at::CPU(at::kFloat).scalarTensor(1.0)); - n->t_(a("beta"), at::CPU(at::kFloat).scalarTensor(1.0)); - return r; + return create(aten::addmm, {*this, mat1, mat2, insertConstant(1.0), insertConstant(1.0)})[0]; } Value * value() const { return v; } private: + Value * insertConstant(IValue value) const { + return jit::insertConstant(*v->owningGraph(), value); + } SymbolicVariable typeLike(SymbolicVariable other) { if (auto other_type = other.v->type()->cast()) v->setType(other_type->contiguous()); diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index 5c998e3fc690bf..aec6eb4ddc9447 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -19,6 +19,51 @@ namespace torch { namespace jit { namespace tracer { //////////////////////////////////////////////////////////////////////////////// namespace detail { +template +void genericAddInput(Node *n, T value) { + n->addInput(insertConstant(*n->owningGraph(), value)); +} + +void badArgType() { + throw std::runtime_error("Found an unsupported argument type in the JIT tracer. File a bug report."); +} + + +void addInputs(Node *n, const char * name, int64_t value) { genericAddInput(n, value); } +void addInputs(Node *n, const char * name, bool value) { genericAddInput(n, value); } +void addInputs(Node *n, const char * name, double value) { genericAddInput(n, value); } +void addInputs(Node *n, const char * name, const at::Scalar& value) { genericAddInput(n, value); } +void addInputs(Node *n, const char * name, const at::Tensor& value) { n->addInput(getValueTrace(value)); } +void addInputs(Node *n, const char * name, const std::string& value) { badArgType(); } +void addInputs(Node *n, const char * name, const at::SparseTensorRef& value) { badArgType(); } + +void addInputs(Node *n, const char * name, at::TensorList value) { + for (auto & t : value) { + n->addInput(getValueTrace(t)); + } +} + +void addInputs(Node *n, const char * name, at::IntList value) { + using ArgumentStash = jit::tracer::ArgumentStash; + std::vector info = ArgumentStash::hasIntList(name) ? + ArgumentStash::popIntList(name) : + ArgumentStash::IntListTrace(value.size()); + + auto& g = getTracingState()->graph; + for (size_t i = 0; i < info.size(); ++i) { + if (info[i] != nullptr) continue; + info[i] = insertConstant(*g, value[i]); + } + for (jit::Value* v : info) { + if (*v->type() != *jit::IntType::get()) { + throw std::runtime_error( + "Type mismatch in setposattr for IntList. Check that your program " + "is valid without tracing, and please file a bug report if it is."); + } + } + n->addInput(g->insertNode(g->createList(jit::IntType::get(), info))->output()); +} + thread_local std::shared_ptr tracing_state; } // namespace detail @@ -36,13 +81,6 @@ TracingState::TracingState() TracingState::~TracingState() = default; -PreTraceInfo preRecordTrace(Symbol op, - at::ArrayRef inputs) { - return makePreTraceInfo(inputs, [&op](const std::shared_ptr& state, Graph& graph) { - return graph.create(op, 0 /* initial outputs */); - }); -} - void postRecordTrace(const PreTraceInfo& info, at::ArrayRef outputs) { for (size_t i = 0; i < outputs.size(); i++) { diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h index bde3edf52221e1..c9780119a385a0 100644 --- a/torch/csrc/jit/tracer.h +++ b/torch/csrc/jit/tracer.h @@ -193,36 +193,58 @@ struct PreTraceInfo { Node *n; }; -TORCH_API PreTraceInfo preRecordTrace(Symbol op, at::ArrayRef inputs); -TORCH_API void postRecordTrace(const PreTraceInfo& info, at::ArrayRef outputs); TORCH_API void recordSourceLocation(Node* n); TORCH_API void setRecordSourceLocation(void (*v)(Node*)); -// We must record the nodes of inputs before we actually carry out -// the operation, because an inplace operation may destroy the information -// we're interested in. See #4480. -template -PreTraceInfo makePreTraceInfo(at::ArrayRef inputs, F ctor) { +namespace detail { + +// NB: those serve both as an intermediate steps in addInputs below, +// as well as the overloads that terminate template recursion +void addInputs(Node *n, const char * name, int64_t value); +void addInputs(Node *n, const char * name, bool value); +void addInputs(Node *n, const char * name, double value); +void addInputs(Node *n, const char * name, const at::Scalar& value); +void addInputs(Node *n, const char * name, const at::Tensor& value); +void addInputs(Node *n, const char * name, at::IntList value); +void addInputs(Node *n, const char * name, at::TensorList value); +void addInputs(Node *n, const char * name, const std::string& value); +void addInputs(Node *n, const char * name, const at::SparseTensorRef& value); + +template +void addInputs(Node *n, const char * name, std::array value) { + throw std::runtime_error("Found an unsupported argument type in the JIT tracer. File a bug report."); +} + +template +void addInputs(Node *n, const char * arg_name, T arg, const char * next_arg_name, Args... args) { + addInputs(n, arg_name, arg); + addInputs(n, next_arg_name, args...); +} + +} // namespace detail + +// NB: if you change this function, you might want to take a look at +// preRecordPythonTrace from python_tracer.cpp +template +PreTraceInfo preRecordTrace(Symbol op, Args... inputs) { PreTraceInfo info; auto & state = getTracingState(); auto & graph = state->graph; - Node *n = ctor(state, *graph); + Node * n = info.n = graph->create(op, /*outputs=*/0); recordSourceLocation(n); - for (const Variable & input : inputs) { - n->addInput(getValueTrace(input)); - } + detail::addInputs(n, inputs...); // NB: Order matters. This must append after inputs but before outputs. graph->appendNode(n); - info.n = n; - return info; } +TORCH_API void postRecordTrace(const PreTraceInfo& info, at::ArrayRef outputs); + TORCH_API autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim); }}} // namespace torch::jit::tracer diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 4ac6ca7f887195..99fd4dd1b25a7b 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -11,7 +11,7 @@ import torch.onnx.utils from collections import Iterable -from functools import partial +from functools import partial, wraps import itertools # EDITING THIS FILE? READ THIS FIRST! @@ -32,6 +32,59 @@ # --------------------------------------------------------------------- +def _parse_arg(value, desc): + if desc == 'v' or not _is_value(value): + return value + + if value.node().kind() != 'onnx::Constant': + raise RuntimeError("ONNX symbolic expected a constant value in the trace") + tval = value.node()['value'] + if desc == 'i': + return int(tval) + elif desc == 'f': + return float(tval) + elif desc == 't': + return tval + elif desc == 'is': + return [int(v) for v in tval] + else: + raise RuntimeError("Casting constants to `{}` is not implemented".format(desc)) + + +def _maybe_get_const(value, desc): + if _is_value(value) and value.node().kind() == 'onnx::Constant': + return _parse_arg(value, desc) + return value + + +def _maybe_get_scalar(value): + value_t = _maybe_get_const(value, 't') + if isinstance(value_t, torch.Tensor) and value_t.shape == (): + return value_t + return value + + +def _get_const(value, desc, arg_name): + if _is_value(value) and value.node().kind() != 'onnx::Constant': + raise RuntimeError("ONNX symbolic expected a constant value of the {} argument".format(arg_name)) + return _parse_arg(value, desc) + + +def parse_args(*arg_descriptors): + def decorator(fn): + def wrapper(g, *args): + assert len(arg_descriptors) == len(args) + args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)] + return fn(g, *args) + # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround + try: + wrapper = wraps(fn)(wrapper) + except Exception: + pass + return wrapper + return decorator + + def _scalar(x): """Convert a scalar tensor into a Python value.""" assert x.numel() == 1 @@ -137,27 +190,33 @@ def unused(g): return g.op("prim::Undefined") +@parse_args('v', 'v', 't') def add(g, self, other, alpha): if _scalar(alpha) != 1: return _unimplemented("add", "alpha != 1") # See Note [Pointwise by scalar] + other = _maybe_get_scalar(other) return g.op("Add", self, _if_scalar_type_as(other, self), **_broadcast_if_scalar(other)) +@parse_args('v', 'v', 't') def sub(g, self, other, alpha): if _scalar(alpha) != 1: return _unimplemented("sub", "alpha != 1") # See Note [Pointwise by scalar] + other = _maybe_get_scalar(other) return g.op("Sub", self, _if_scalar_type_as(other, self), **_broadcast_if_scalar(other)) def mul(g, self, other): # See Note [Pointwise by scalar] + other = _maybe_get_scalar(other) return g.op("Mul", self, _if_scalar_type_as(other, self), **_broadcast_if_scalar(other)) def div(g, self, other): # See Note [Pointwise by scalar] + other = _maybe_get_scalar(other) return g.op("Div", self, _if_scalar_type_as(other, self), **_broadcast_if_scalar(other)) @@ -166,9 +225,9 @@ def reciprocal(g, self): # This syntax is Python 2 portable -def cat(g, *tensors, **kwargs): - dim = kwargs.pop("dim") - assert not kwargs +def cat(g, *args): + dim = _get_const(args[-1], 'i', 'dim') + tensors = args[:-1] return g.op("Concat", *tensors, axis_i=dim) @@ -188,6 +247,7 @@ def matmul(g, self, other): return g.op("MatMul", self, other) +@parse_args('v', 'v', 'v', 't', 't') def addmm(g, self, mat1, mat2, beta, alpha): return g.op("Gemm", mat1, mat2, self, beta_f=_scalar(beta), alpha_f=_scalar(alpha)) @@ -211,12 +271,13 @@ def sigmoid(g, self): def _reduce_op_symbolic(onnx_op_name): def symbolic(g, self, dim=None, keepdim=None): params = {} - if dim is not None: - if isinstance(dim, numbers.Number): - dim = [dim] - params['axes_i'] = dim - params['keepdims_i'] = int(bool(keepdim)) - return g.op(onnx_op_name, self, **params) + if dim is None: + # all-reduce path + return g.op(onnx_op_name, self, keepdims_i=0) + else: + # dim-reduce path + dim, keepdim = _get_const(dim, 'i', 'dim'), _get_const(keepdim, 'i', 'keepdim') + return g.op(onnx_op_name, self, axes_i=[dim], keepdims_i=keepdim) return symbolic mean = _reduce_op_symbolic('ReduceMean') @@ -224,6 +285,7 @@ def symbolic(g, self, dim=None, keepdim=None): prod = _reduce_op_symbolic('ReduceProd') +@parse_args('v', 'i') def cumsum(g, input, dim): return g.op("ATen", input, operator_s="cumsum", dim_i=dim) @@ -241,6 +303,7 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): return g.op("Gather", weight, indices) +@parse_args('v', 'v', 'v', 'i', 'i', 'i') def embedding_bag(g, embedding_matrix, indices, @@ -260,14 +323,11 @@ def embedding_bag(g, def size(g, self, dim): - if _is_value(dim): - if dim.node().kind() != 'onnx::Constant': - raise RuntimeError("ONNX export only supports constant dim values in .size()") - dim = int(dim.node().t('value')) full_shape = g.op("Shape", self) - return select(g, full_shape, dim=0, index=dim) + return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim) +@parse_args('v', 'i', 'i') def transpose(g, self, dim0, dim1): if dim0 == dim1: # micro-optimization return self @@ -278,6 +338,7 @@ def transpose(g, self, dim0, dim1): return g.op("Transpose", self, perm_i=axes) +@parse_args('v', 'is') def permute(g, self, dims): if dims == list(range(0, len(dims))): return self @@ -285,6 +346,7 @@ def permute(g, self, dims): def view(g, self, size): + size = _maybe_get_const(size, 'is') if _is_value(size): shape = size else: @@ -296,16 +358,12 @@ def view(g, self, size): return g.op("Reshape", self, shape) -def stack(g, *tensors, **kwargs): - dim = kwargs.pop('dim') - if kwargs: - raise RuntimeError("Unexpected kwargs: " + ','.join(kwargs.keys())) - if len(tensors) < 1: - raise RuntimeError("Expected at least one argument to stack node") - unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in tensors] - return g.op("Concat", *unsqueezed, axis_i=dim) +def stack(g, *args): + unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in args[:-1]] + [args[-1]] + return concat(g, *unsqueezed) +@parse_args('v', 'i', 'i') def split(g, self, split_size, dim): size = self.type().sizes()[dim] splits = [split_size] * (size // split_size) @@ -319,11 +377,13 @@ def split(g, self, split_size, dim): # less sensitive to changes in input size. # TODO: Once we have proper scoping, stop reimplementing chunk, delete this # method, and use the desugared version +@parse_args('v', 'i', 'i') def chunk(g, self, chunks, dim): split_size = (self.type().sizes()[dim] + chunks - 1) // chunks return split(g, self, split_size, dim) +@parse_args('v', 'i', 'i') def select(g, self, dim, index): slice_node = g.op("Slice", self, axes_i=[dim], starts_i=[index], ends_i=[index + 1]) return g.op("Squeeze", slice_node, axes_i=[dim]) @@ -336,7 +396,7 @@ def squeeze(g, self, dim=None): if size == 1: dims.append(i) else: - dims = [dim] + dims = [_get_const(dim, 'i', 'dim')] return g.op("Squeeze", self, axes_i=dims) @@ -348,6 +408,7 @@ def relu(g, input): return g.op("Relu", input) +@parse_args('v', 't', 't') def threshold(g, self, threshold, value): # See Note [Export inplace] if _scalar(threshold) != 0: @@ -358,11 +419,13 @@ def threshold(g, self, threshold, value): def leaky_relu(g, input, negative_slope, inplace=False): + negative_slope = _get_const(negative_slope, 't', 'negative_slope') # See Note [Export inplace] # TODO: Talk to ONNX about unconditional cast of scalar to float return g.op("LeakyRelu", input, alpha_f=_scalar(negative_slope)) +@parse_args('v', 'i') def glu(g, input, dim): assert input.type().sizes()[dim] % 2 == 0 @@ -370,7 +433,8 @@ def glu(g, input, dim): return g.op('Mul', first, g.op('Sigmoid', second)) -def softmax(g, input, dim=None): +@parse_args('v', 'i') +def softmax(g, input, dim): # Softmax does normalization at vector level. # PyTorch and ONNX use different strategies to split the input tensor into vectors. # Thus dim and axis have different meanings. @@ -394,12 +458,14 @@ def softmax(g, input, dim=None): return g.op('Softmax', input, axis_i=dim) +@parse_args('v', 't', 'v') def softplus(g, self, beta, threshold): if beta != 1: return _unimplemented("beta", "has to be 1") return g.op('Softplus', self) +@parse_args('v', 'is', 'is', 'is', 'is', 'i') def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode): if ceil_mode: return _unimplemented("max_pool1d_with_indices", "ceil_mode") @@ -414,6 +480,7 @@ def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ce return r, None +@parse_args('v', 'is', 'is', 'is', 'is', 'i') def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode): if ceil_mode: return _unimplemented("max_pool2d_with_indices", "ceil_mode") @@ -428,6 +495,7 @@ def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ce return r, None +@parse_args('v', 'is', 'is', 'is', 'is', 'i') def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode): if ceil_mode: return _unimplemented("max_pool3d_with_indices", "ceil_mode") @@ -443,6 +511,7 @@ def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ce def _avg_pool(name, tuple_fn): + @parse_args('v', 'is', 'is', 'is', 'i', 'i') def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad): if ceil_mode: return _unimplemented("avg_pool2d", "ceil_mode") @@ -469,6 +538,7 @@ def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include avg_pool3d = _avg_pool('avg_pool3d', _triple) +@parse_args('v', 'is') def reflection_pad(g, input, padding): from torch.autograd._functions.utils import prepare_onnx_paddings mode = "reflect" @@ -476,6 +546,7 @@ def reflection_pad(g, input, padding): return g.op("Pad", input, pads_i=paddings, mode_s=mode) +@parse_args('v', 'is') def replication_pad(g, input, padding): from torch.autograd._functions.utils import prepare_onnx_paddings mode = "edge" @@ -491,6 +562,7 @@ def replication_pad(g, input, padding): replication_pad3d = replication_pad +@parse_args('v', 'is') def upsample_nearest2d(g, input, output_size): return g.op("Upsample", input, height_scale_f=float(output_size[-2]) / input.type().sizes()[-2], @@ -498,6 +570,7 @@ def upsample_nearest2d(g, input, output_size): mode_s="nearest") +@parse_args('v', 'is', 'i') def upsample_bilinear2d(g, input, output_size, align_corners): if align_corners: return _unimplemented("upsample_bilinear2d", "align_corners == True") @@ -508,10 +581,12 @@ def upsample_bilinear2d(g, input, output_size, align_corners): def gt(g, input, other): + other = _maybe_get_scalar(other) return g.op("Greater", input, _if_scalar_type_as(other, input), **_broadcast_if_scalar(other)) def lt(g, input, other): + other = _maybe_get_scalar(other) return g.op("Less", input, _if_scalar_type_as(other, input), **_broadcast_if_scalar(other)) @@ -523,10 +598,12 @@ def le(g, input, other): return g.op("Not", gt(g, other, input)) +@parse_args('v', 'i') def log_softmax(g, input, dim=None): return g.op("LogSoftmax", input, axis_i=dim) +@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i') def _convolution(g, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled): weight_size = weight.type().sizes() @@ -560,6 +637,7 @@ def _convolution(g, input, weight, bias, stride, padding, dilation, return n +@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i') def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled): input_sizes = input.type().sizes() if len(input_sizes) == 2: @@ -586,10 +664,12 @@ def batch_norm(g, input, weight, bias, running_mean, running_var, training, mome return res +@parse_args('v', 'i', 'i', 'i') def unfold(g, input, dimension, size, step): return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step) +@parse_args('v', 't', 't') def elu(g, input, alpha, scale): if scale and scale != 1.: return _unimplemented("scale", "does not support scale in Elu") @@ -601,12 +681,13 @@ def selu(g, input): return g.op("Selu", input) -def index_select(g, self, index, dim): +@parse_args('v', 'i', 'v') +def index_select(g, self, dim, index): return g.op("Gather", self, index, axis_i=dim) -def index_put(g, *inputs, **kwargs): - return g.op("ATen", *inputs, operator_s='index_put', **kwargs) +def index_put(g, *inputs): + return g.op("ATen", *inputs, operator_s='index_put') def type_as(g, self, other): @@ -631,29 +712,33 @@ def abs(g, self): def pow(g, self, exponent): + exponent = _maybe_get_scalar(exponent) return g.op("Pow", self, _if_scalar_type_as(exponent, self), **_broadcast_if_scalar(exponent)) +@parse_args('v', 'f', 'f') def clamp(g, self, min, max): return g.op("Clip", self, min_f=min, max_f=max) +@parse_args('v', 'f') def clamp_min(g, self, min): return g.op("Clip", self, min_f=min) +@parse_args('v', 'f') def clamp_max(g, self, max): return g.op("Clip", self, max_f=max) # torch.max (same for torch.min) actually has two interfaces smashed together: # torch.max(x, dim, keepdim) and torch.max(x, y) -def max(g, self, *args, **kwargs): - dim = kwargs.get("dim", None) - if dim is None and isinstance(args[0], numbers.Number): - dim = args[0] - if dim is not None: - keepdim = kwargs.get("keepdim", False) +def max(g, self, dim_or_y, keepdim=None): + if keepdim is None: + return g.op("Max", self, dim_or_y) + else: + dim = _get_const(dim_or_y, 'i', 'dim') + keepdim = _get_const(keepdim, 'i', 'keepdim') # TODO: export it as ReduceMax return g.op("ATen", self, @@ -661,27 +746,21 @@ def max(g, self, *args, **kwargs): dim_i=dim, keepdim_i=keepdim, outputs=2) - else: - (other,) = args - return g.op("Max", self, other) -def min(g, self, *args, **kwargs): - dim = kwargs.get("dim", None) - if dim is None and isinstance(args[0], numbers.Number): - dim = args[0] - if dim is not None: - keepdim = kwargs.get("keepdim", False) - # TODO: export it as ReduceMin +def min(g, self, dim_or_y, keepdim=None): + if keepdim is None: + return g.op("Min", self, dim_or_y) + else: + dim = _get_const(dim_or_y, 'i', 'dim') + keepdim = _get_const(keepdim, 'i', 'keepdim') + # TODO: export it as ReduceMax return g.op("ATen", self, operator_s="min", dim_i=dim, keepdim_i=keepdim, outputs=2) - else: - (other,) = args - return g.op("Min", self, other) def eq(g, self, other): @@ -692,6 +771,7 @@ def exp(g, self): return g.op("Exp", self) +@parse_args('v', 't', 'i', 'i') def norm(g, self, p, dim, keepdim): if p == 1: f = _reduce_op_symbolic("ReduceL1") @@ -702,10 +782,12 @@ def norm(g, self, p, dim, keepdim): return f(g, self, dim=dim, keepdim=keepdim) +@parse_args('v', 'v', 'v', 'i') def conv_tbc(g, input, weight, bias, pad): return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad) +@parse_args('v', 'i', 'i') def _unique(g, input, sorted, return_inverse): return g.op("ATen", input, operator_s="_unique", sorted_i=sorted, return_inverse_i=return_inverse, outputs=2) @@ -746,7 +828,7 @@ def _cast_func_template(to_i, g, input, non_blocking): for k, v in cast_pytorch_to_onnx.items(): name = '_cast_{}'.format(k) - globals()[name] = partial(_cast_func_template, v) + globals()[name] = parse_args('v', 'i')(partial(_cast_func_template, v)) def zeros_like(g, input): @@ -755,15 +837,17 @@ def zeros_like(g, input): def full_like(g, input, fill_value): # TODO: a more efficient implementation (ConstantFill?) - return add(g, zeros_like(g, input), fill_value, alpha=torch.tensor(1)) + return add(g, zeros_like(g, input), fill_value, g.op("Constant", value_t=torch.tensor(1))) +@parse_args('v', 'i', 'i', 'i', 'i') def slice(g, self, dim, start, end, step): if step != 1: _unimplemented("slice", "step!=1 is currently not supported") return g.op("Slice", self, axes_i=[dim], starts_i=[start], ends_i=[end]) +@parse_args('v', 'f', 'f') def hardtanh(g, self, min_val, max_val): return g.op("Clip", self, min_f=min_val, max_f=max_val) @@ -772,11 +856,13 @@ def alias(g, self): return self +@parse_args('v', 'i') def unsqueeze(g, self, dim): return g.op("Unsqueeze", self, axes_i=[dim]) -def topk(g, self, k, dim=None, largest=True, sorted=True, out=None): +@parse_args('v', 'i', 'i', 'i', 'i') +def topk(g, self, k, dim, largest, sorted, out=None): if out is not None: _unimplemented("TopK", "Out parameter is not supported for topk") if not largest: @@ -785,6 +871,7 @@ def topk(g, self, k, dim=None, largest=True, sorted=True, out=None): return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2) +@parse_args('v', 'is') def repeat(g, self, repeats): if self.isTensor(): sizes = self.type().sizes() @@ -1041,5 +1128,6 @@ def retrieve_state(x, start, end): return symbolic +@parse_args('v', 'i') def _dim_arange(g, like, dim): return g.op('ATen', like, dim_i=dim, operator_s='_dim_arange') diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 59d567f461789b..7ce8220ff72b45 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -94,10 +94,23 @@ def export(model, args, f, export_params=True, verbose=False, training=False, operator_export_type=operator_export_type) -def _optimize_graph(graph, operator_export_type): +def _list_constant_prop(g, block): + for node in block.nodes(): + for subblock in node.blocks(): + _list_constant_prop(g, subblock) + if node.kind() == "prim::ListConstruct": + input_nodes = [i.node() for i in node.inputs()] + if all(inode.kind() == "prim::Constant" and inode.kindOf("value") == "i" for inode in input_nodes): + input_values = [inode['value'] for inode in input_nodes] + const_node = g.create("prim::Constant") + const_node.insertBefore(node) + const_node.is_("value", input_values) + const_node.output().setType(torch._C.ListType.ofInts()) + node.output().replaceAllUsesWith(const_node.output()) - # onnx only supports tensors, so we turn all out number types into tensors - torch._C._jit_pass_erase_number_types(graph) + +def _optimize_graph(graph, operator_export_type): + _list_constant_prop(graph, graph) # run dce to eliminate dead parts of the graph that might have been # left behind by things like symbolic_override @@ -106,6 +119,11 @@ def _optimize_graph(graph, operator_export_type): torch._C._jit_pass_peephole(graph) torch._C._jit_pass_lint(graph) + + # onnx only supports tensors, so we turn all out number types into tensors + torch._C._jit_pass_erase_number_types(graph) + torch._C._jit_pass_peephole(graph) + if operator_export_type != OperatorExportTypes.RAW: graph = torch._C._jit_pass_onnx(graph, operator_export_type) torch._C._jit_pass_lint(graph) @@ -452,7 +470,14 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor elif ns == "prim": if op_name == "Constant": - return g.op("Constant", value_t=n["value"]) + if n.kindOf("value") == "t": + return g.op("Constant", value_t=n["value"]) + elif n.kindOf("value") == "is": + value = torch.stack([torch.tensor(v) for v in n["value"]]) if n["value"] else [] + return g.op("Constant", value_t=value) + else: + raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format( + n.kindOf("value"))) elif op_name == "ListConstruct": unsqueezed = [g.op("Unsqueeze", input, axes_i=[0]) for input in inputs] return g.op("Concat", *unsqueezed, axis_i=0) From 4a192bcc3d297d5934dce2dc46cc11ff6b4baf2c Mon Sep 17 00:00:00 2001 From: Junjie Bai Date: Thu, 26 Jul 2018 23:26:07 -0700 Subject: [PATCH 15/17] Rename onnx integration tests file to avoid confusion Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9913 Differential Revision: D9026787 Pulled By: bddppq fbshipit-source-id: a3e7e79973abc4f5fe163f3e86b24382a1efd082 --- test/onnx/{test_caffe2.py => test_pytorch_onnx_caffe2.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/onnx/{test_caffe2.py => test_pytorch_onnx_caffe2.py} (100%) diff --git a/test/onnx/test_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py similarity index 100% rename from test/onnx/test_caffe2.py rename to test/onnx/test_pytorch_onnx_caffe2.py From a709f232257450fd899ec829164bc488a0354b63 Mon Sep 17 00:00:00 2001 From: tomguluson92 <314913739@qq.com> Date: Fri, 27 Jul 2018 00:45:42 -0700 Subject: [PATCH 16/17] revise a little spell mistake in tensor.py (#9868) Summary: Hello! I just find a small spell mistake while reading this source code. Just PR it, Thx! Pull Request resolved: https://github.com/pytorch/pytorch/pull/9868 Reviewed By: gchanan, ezyang Differential Revision: D9016030 Pulled By: soumith fbshipit-source-id: fc3877177be080adbdbda99a169e401691292ebb --- torch/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/tensor.py b/torch/tensor.py index b9a35e39ae6952..67e195466e8781 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -242,7 +242,7 @@ def btrifact(self, info=None, pivot=True): "consider using btrifact_with_info instead", stacklevel=2) factorization, pivots, _info = super(Tensor, self).btrifact_with_info(pivot=pivot) if info.type() != _info.type(): - raise ValueError('btrifact expects info to be an IntTenor') + raise ValueError('btrifact expects info to be an IntTensor') info.resize_as_(_info).copy_(_info) return factorization, pivots else: From 7b375ed362c4da833c38691b6bf3ade0941d3bf1 Mon Sep 17 00:00:00 2001 From: Changmao Cheng Date: Fri, 27 Jul 2018 00:54:51 -0700 Subject: [PATCH 17/17] fix ParameterDict doc Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9918 Differential Revision: D9026402 Pulled By: soumith fbshipit-source-id: d0459dcda631e8921ab39725b9045e03960da5c9 --- torch/nn/modules/container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index d5f08aa10273fb..454151afed8201 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -411,7 +411,7 @@ class ParameterDict(Module): class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() - self.choices = nn.ParameterDict({ + self.params = nn.ParameterDict({ 'left': nn.Parameter(torch.randn(5, 10)), 'right': nn.Parameter(torch.randn(5, 10)) })