diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 6cd8722316297e..be06656a3dee7c 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -17,7 +17,7 @@ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ + AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ } \ }() @@ -27,9 +27,9 @@ switch (the_type.scalarType()) { \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, Half, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ + AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ } \ }() @@ -43,7 +43,7 @@ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ + AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ } \ }() @@ -59,7 +59,7 @@ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ + AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ } \ }() @@ -74,8 +74,8 @@ AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, Half, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ default: \ - AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ + AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \ } \ }() diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 1ddac71cf299b0..d6ebbd4573a70c 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -581,6 +581,9 @@ Tensor hamming_window( double beta, const TensorOptions& options) { window_function_checks("hamming_window", options, window_length); + if (window_length == 0) { + return native::empty({0}, options); + } if (window_length == 1) { return native::ones({1}, options); } diff --git a/binaries/benchmark_helper.cc b/binaries/benchmark_helper.cc index 52b51174cf34d1..27a593aaa81963 100644 --- a/binaries/benchmark_helper.cc +++ b/binaries/benchmark_helper.cc @@ -215,7 +215,8 @@ void runNetwork( const bool wipe_cache, const bool run_individual, const int warmup, - const int iter) { + const int iter, + const int sleep_before_run) { if (!net_def.has_name()) { net_def.set_name("benchmark"); } @@ -234,6 +235,9 @@ void runNetwork( if (wipe_cache) { caffe2::wipe_cache(); } + if (sleep_before_run > 0) { + sleep(sleep_before_run); + } LOG(INFO) << "Main runs."; CAFFE_ENFORCE( iter >= 0, diff --git a/binaries/benchmark_helper.h b/binaries/benchmark_helper.h index 0a52e16a50079c..5af2d91cec4bc7 100644 --- a/binaries/benchmark_helper.h +++ b/binaries/benchmark_helper.h @@ -96,4 +96,5 @@ void runNetwork( const bool, const bool, const int, + const int, const int); diff 
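
Note on the TensorFactories.cpp hunk above: hamming_window now short-circuits a zero-length request instead of falling through to the general cosine-window code. A minimal sketch of the resulting Python-level behavior (illustrative only, assuming the standard torch.hamming_window binding; not part of this patch):

```python
import torch

# With the new guard, a zero-length window returns an empty tensor
# instead of reaching the general cosine-window path.
w0 = torch.hamming_window(0)
print(w0.shape)    # torch.Size([0])

# The existing single-element special case is unchanged: a length-1 window is all ones.
w1 = torch.hamming_window(1)
print(w1)          # tensor([1.])
```
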
--git a/binaries/caffe2_benchmark.cc b/binaries/caffe2_benchmark.cc index 729479a17c7598..230210644947cd 100644 --- a/binaries/caffe2_benchmark.cc +++ b/binaries/caffe2_benchmark.cc @@ -62,6 +62,10 @@ CAFFE2_DEFINE_bool( run_individual, false, "Whether to benchmark individual operators."); +CAFFE2_DEFINE_int( + sleep_before_run, + 0, + "The seconds to sleep before starting the benchmarking."); CAFFE2_DEFINE_bool( text_output, false, @@ -115,7 +119,8 @@ int main(int argc, char** argv) { caffe2::FLAGS_wipe_cache, caffe2::FLAGS_run_individual, caffe2::FLAGS_warmup, - caffe2::FLAGS_iter); + caffe2::FLAGS_iter, + caffe2::FLAGS_sleep_before_run); writeOutput( workspace, diff --git a/binaries/predictor_verifier.cc b/binaries/predictor_verifier.cc index e82a8e9d2cec85..e8e29f29559cee 100644 --- a/binaries/predictor_verifier.cc +++ b/binaries/predictor_verifier.cc @@ -16,7 +16,7 @@ #include "caffe2/core/flags.h" #include "caffe2/core/init.h" -#include "caffe2/core/predictor.h" +#include "caffe2/predictor/predictor.h" #include "caffe2/utils/proto_utils.h" CAFFE2_DEFINE_string(init_net, "", "The given path to the init protobuffer."); diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index ff096660945ecb..0d84ccbfb606a1 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -65,6 +65,7 @@ if(BUILD_CAFFE2) add_subdirectory(proto) add_subdirectory(contrib) add_subdirectory(core) + add_subdirectory(predictor) add_subdirectory(core/nomnigraph) add_subdirectory(core/dispatch) if (USE_NVRTC) diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index 325ccd3761afb3..757019ef64f07a 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -533,10 +533,11 @@ class Operator : public OperatorBase { return fillers; } -#define DISABLE_INPUT_FILLERS(Context) \ - std::vector> InputFillers( \ - const std::vector>& /* unused */) override { \ - throw UnsupportedOperatorFeature("Op does not have input fillers"); \ +#define DISABLE_INPUT_FILLERS(Context) \ + std::vector> InputFillers( \ + const std::vector>& /* unused */) override { \ + throw UnsupportedOperatorFeature( \ + OperatorBase::type() + " does not have input fillers"); \ } void SparseLengthsFillerHelper( @@ -554,7 +555,8 @@ class Operator : public OperatorBase { size_t segment_index, std::vector>* fillers) { CAFFE_ENFORCE_EQ(shapes[segment_index].size(), 1); - // TODO: what would be a proper #segments + // TODO (mnaumov): distribution of value + (*fillers)[value_index].Min(0).Max(shapes[value_index].front() * 2); (*fillers)[segment_index].SparseSegments(shapes[value_index].front() - 1); } diff --git a/caffe2/ideep/operators/operator_fallback_ideep.cc b/caffe2/ideep/operators/operator_fallback_ideep.cc index 0d8b6fd55b205b..16df4962b4284c 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.cc +++ b/caffe2/ideep/operators/operator_fallback_ideep.cc @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include #include #include @@ -112,4 +114,12 @@ REGISTER_IDEEP_OPERATOR( PRelu, IDEEPFallbackOp>); +// ctc decoder operators +REGISTER_IDEEP_OPERATOR( + CTCGreedyDecoder, + IDEEPFallbackOp>); +REGISTER_IDEEP_OPERATOR( + CTCBeamSearchDecoder, + IDEEPFallbackOp>); + } // namespace caffe2 diff --git a/caffe2/image/image_input_op.h b/caffe2/image/image_input_op.h index a8c45ca87d46a1..6bf232977d92f2 100644 --- a/caffe2/image/image_input_op.h +++ b/caffe2/image/image_input_op.h @@ -658,8 +658,16 @@ bool ImageInputOp::GetImageAndLabelAndInfoFromDBValue( for (int j = 0; j < additional_output_proto.int64_data_size(); 
++j) { additional_output[j] = additional_output_proto.int64_data(j); } - } - else { + } else if (additional_output_proto.data_type() == TensorProto::UINT8) { + uint8_t* additional_output = + prefetched_additional_outputs_[i].template mutable_data() + + item_id * additional_output_proto.int32_data_size(); + + for (int j = 0; j < additional_output_proto.int32_data_size(); ++j) { + additional_output[j] = + static_cast(additional_output_proto.int32_data(j)); + } + } else { LOG(FATAL) << "Unsupported output type."; } } @@ -1148,6 +1156,9 @@ bool ImageInputOp::Prefetch() { } else if ( additional_output_proto.data_type() == TensorProto::INT64) { prefetched_additional_outputs_[i].template mutable_data(); + } else if ( + additional_output_proto.data_type() == TensorProto::UINT8) { + prefetched_additional_outputs_[i].template mutable_data(); } else { LOG(FATAL) << "Unsupported output type."; } diff --git a/caffe2/mobile/contrib/ios/ios_caffe.cc b/caffe2/mobile/contrib/ios/ios_caffe.cc index 12e0e5598c6aa0..f1bcf4a5b1087a 100644 --- a/caffe2/mobile/contrib/ios/ios_caffe.cc +++ b/caffe2/mobile/contrib/ios/ios_caffe.cc @@ -1,8 +1,8 @@ #include "ios_caffe.h" -#include "caffe2/core/predictor.h" #include "caffe2/core/tensor.h" #include "caffe2/mobile/contrib/ios/ios_caffe_predictor.h" +#include "caffe2/predictor/predictor.h" Caffe2IOSPredictor* MakeCaffe2Predictor(const std::string& init_net_str, const std::string& predict_net_str, diff --git a/caffe2/mobile/contrib/ios/ios_caffe.h b/caffe2/mobile/contrib/ios/ios_caffe.h index 3fbd235a74f706..7b5f8170405b6f 100644 --- a/caffe2/mobile/contrib/ios/ios_caffe.h +++ b/caffe2/mobile/contrib/ios/ios_caffe.h @@ -3,9 +3,9 @@ #include #include -#include "caffe2/core/predictor.h" #include "caffe2/mobile/contrib/ios/ios_caffe_defines.h" #include "caffe2/mobile/contrib/ios/ios_caffe_predictor.h" +#include "caffe2/predictor/predictor.h" extern "C" { diff --git a/caffe2/mobile/contrib/ios/ios_caffe_predictor.h b/caffe2/mobile/contrib/ios/ios_caffe_predictor.h index 0b065d3c426956..a51711ce0558e5 100644 --- a/caffe2/mobile/contrib/ios/ios_caffe_predictor.h +++ b/caffe2/mobile/contrib/ios/ios_caffe_predictor.h @@ -3,8 +3,8 @@ #include #include "caffe2/core/net.h" -#include "caffe2/core/predictor.h" #include "caffe2/mobile/contrib/ios/ios_caffe_defines.h" +#include "caffe2/predictor/predictor.h" struct Tensor { std::vector dims; diff --git a/caffe2/mobile/contrib/opengl/core/GLPredictor.h b/caffe2/mobile/contrib/opengl/core/GLPredictor.h index 2806f8a0408293..24c319759bd7d1 100644 --- a/caffe2/mobile/contrib/opengl/core/GLPredictor.h +++ b/caffe2/mobile/contrib/opengl/core/GLPredictor.h @@ -3,7 +3,7 @@ #include "GLImage.h" #include "caffe2/core/net.h" -#include "caffe2/core/predictor.h" +#include "caffe2/predictor/predictor.h" namespace caffe2 { class GLPredictor : public Predictor { diff --git a/caffe2/mobile/contrib/opengl/core/rewrite_net.h b/caffe2/mobile/contrib/opengl/core/rewrite_net.h index c3c47d63f75065..d0bc921a8876ca 100644 --- a/caffe2/mobile/contrib/opengl/core/rewrite_net.h +++ b/caffe2/mobile/contrib/opengl/core/rewrite_net.h @@ -1,7 +1,7 @@ #pragma once #include "GLPredictor.h" -#include "caffe2/core/predictor.h" +#include "caffe2/predictor/predictor.h" namespace caffe2 { bool tryConvertToOpenGL(const NetDef& initNet, diff --git a/caffe2/onnx/backend_rep.h b/caffe2/onnx/backend_rep.h index 5fe503bbe7ad98..fb46d19d10ba43 100644 --- a/caffe2/onnx/backend_rep.h +++ b/caffe2/onnx/backend_rep.h @@ -1,6 +1,6 @@ #pragma once -#include "caffe2/core/predictor.h" 
+#include "caffe2/predictor/predictor.h" #include "caffe2/proto/caffe2.pb.h" #include diff --git a/caffe2/operators/lengths_reducer_fused_8bit_rowwise_ops.h b/caffe2/operators/lengths_reducer_fused_8bit_rowwise_ops.h index 7c42d522f2e71f..198e4d81f772a3 100644 --- a/caffe2/operators/lengths_reducer_fused_8bit_rowwise_ops.h +++ b/caffe2/operators/lengths_reducer_fused_8bit_rowwise_ops.h @@ -68,7 +68,21 @@ class SparseLengthsFused8BitRowwiseOp : public Operator { return true; } - USE_VALUE_KEY_LENGTH_INPUT_FILLERS(Context, DATA, INDICES, LENGTHS) + std::vector> InputFillers( + const std::vector>& shapes) override { + CAFFE_ENFORCE_EQ(shapes.size(), Operator::Inputs().size()); + auto fillers = Operator::InputFillers(shapes); + if (with_weights) { + // TODO: enable the fillers + throw UnsupportedOperatorFeature( + OperatorBase::type() + " does not have input fillers"); + } + Operator::SparseLengthsFillerHelper( + shapes, INDICES, LENGTHS, &fillers); + Operator::SparseSegmentsFillerHelper( + shapes, DATA, INDICES, &fillers); + return fillers; + } private: enum { diff --git a/caffe2/operators/lengths_reducer_ops.h b/caffe2/operators/lengths_reducer_ops.h index 505dad1b102de3..f96c379ae81624 100644 --- a/caffe2/operators/lengths_reducer_ops.h +++ b/caffe2/operators/lengths_reducer_ops.h @@ -92,7 +92,21 @@ class CPUSparseLengthsReductionOp : public Operator { return true; } - USE_VALUE_KEY_LENGTH_INPUT_FILLERS(CPUContext, DATA, INDICES, LENGTHS) + std::vector> InputFillers( + const std::vector>& shapes) override { + CAFFE_ENFORCE_EQ(shapes.size(), Operator::Inputs().size()); + auto fillers = Operator::InputFillers(shapes); + if (USE_WEIGHT) { + // TODO: enable the fillers + throw UnsupportedOperatorFeature( + OperatorBase::type() + " does not have input fillers"); + } + Operator::SparseLengthsFillerHelper( + shapes, INDICES, LENGTHS, &fillers); + Operator::SparseSegmentsFillerHelper( + shapes, DATA, INDICES, &fillers); + return fillers; + } private: enum { diff --git a/caffe2/operators/one_hot_ops.cc b/caffe2/operators/one_hot_ops.cc index bb8a1dbc774413..4465e1c744044f 100644 --- a/caffe2/operators/one_hot_ops.cc +++ b/caffe2/operators/one_hot_ops.cc @@ -172,6 +172,9 @@ class SegmentOneHotOp : public Operator { SegmentOneHotOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws) {} + // TODO: enable input filler + DISABLE_INPUT_FILLERS(CPUContext) + bool RunOnDevice() override { auto& lengths = Input(0); auto& indices = Input(1); diff --git a/caffe2/operators/one_hot_ops.h b/caffe2/operators/one_hot_ops.h index 1b48b69326f3e7..644b3e74dd978f 100644 --- a/caffe2/operators/one_hot_ops.h +++ b/caffe2/operators/one_hot_ops.h @@ -13,6 +13,9 @@ class OneHotOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; + // TODO: enable input filler + DISABLE_INPUT_FILLERS(Context) + OneHotOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws) {} @@ -58,6 +61,8 @@ class BatchOneHotOp final : public Operator { BatchOneHotOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws) {} + USE_VALUE_KEY_LENGTH_INPUT_FILLERS(Context, X, VALS, LENS) + bool RunOnDevice() override { return DispatchHelper>::call(this, Input(X)); } @@ -83,6 +88,9 @@ class BatchBucketOneHotOp final : public Operator { bool RunOnDevice() override; + // TODO: enable input filler + DISABLE_INPUT_FILLERS(Context) + protected: INPUT_TAGS(X, LENS, BOUNDARIES); OUTPUT_TAGS(ONE_HOT); diff --git a/caffe2/operators/order_switch_ops.cc 
b/caffe2/operators/order_switch_ops.cc
index 11cc6dedc24f9f..7296f9a74afa51 100644
--- a/caffe2/operators/order_switch_ops.cc
+++ b/caffe2/operators/order_switch_ops.cc
@@ -10,16 +10,10 @@
 bool NHWC2NCHWOp::RunOnDevice() {
   const int N = X.dim32(0), H = X.dim32(1), W = X.dim32(2), C = X.dim32(3);
   Y->Resize(N, C, H, W);
   const float* Xdata = X.data();
-  float* Ydata = Y->mutable_data();
-  for (int n = 0; n < N; ++n) {
-    for (int h = 0; h < H; ++h) {
-      for (int w = 0; w < W; ++w) {
-        for (int c = 0; c < C; ++c) {
-          Ydata[((n * C + c) * H + h) * W + w] = *(Xdata++);
-        }
-      }
-    }
-  }
+  float* Ydata = Y->template mutable_data();
+  std::array dims = {N, H, W, C};
+  std::array axes = {0, 3, 1, 2};
+  math::Transpose(4, dims.data(), axes.data(), Xdata, Ydata, &context_);
   return true;
 }
@@ -31,20 +25,13 @@
 
 bool NCHW2NHWCOp::RunOnDevice() {
   const int N = X.dim32(0), C = X.dim32(1), H = X.dim32(2), W = X.dim32(3);
   Y->Resize(N, H, W, C);
   const float* Xdata = X.data();
-  float* Ydata = Y->mutable_data();
-  for (int n = 0; n < N; ++n) {
-    for (int c = 0; c < C; ++c) {
-      for (int h = 0; h < H; ++h) {
-        for (int w = 0; w < W; ++w) {
-          Ydata[((n * H + h) * W + w) * C + c] = *(Xdata++);
-        }
-      }
-    }
-  }
+  float* Ydata = Y->template mutable_data();
+  std::array dims = {N, C, H, W};
+  std::array axes = {0, 2, 3, 1};
+  math::Transpose(4, dims.data(), axes.data(), Xdata, Ydata, &context_);
   return true;
 }
-
 REGISTER_CPU_OPERATOR(NHWC2NCHW, NHWC2NCHWOp);
 REGISTER_CPU_OPERATOR(NCHW2NHWC, NCHW2NHWCOp);
@@ -102,4 +89,4 @@ class GetNCHW2NHWCGradient : public GradientMakerBase {
   }
 };
 REGISTER_GRADIENT(NCHW2NHWC, GetNCHW2NHWCGradient);
-} // namespace caffe2
+} // namespace caffe2
diff --git a/caffe2/operators/slice_op.h b/caffe2/operators/slice_op.h
index 12734a8e33df71..01eed59598a87c 100644
--- a/caffe2/operators/slice_op.h
+++ b/caffe2/operators/slice_op.h
@@ -212,6 +212,9 @@ class SliceOp : public Operator {
     return RunOnDeviceImpl(Input(0), Output(0));
   }
 
+  // This cannot be enabled given the output dims depends on the input
+  DISABLE_INPUT_FILLERS(Context)
+
  protected:
   bool RunOnDeviceImpl(const Tensor& data, Tensor* output) {
     if (InputSize() > 1) {
diff --git a/caffe2/predictor/CMakeLists.txt b/caffe2/predictor/CMakeLists.txt
new file mode 100644
index 00000000000000..1038a84af38da5
--- /dev/null
+++ b/caffe2/predictor/CMakeLists.txt
@@ -0,0 +1,13 @@
+set(Caffe2_PREDICTOR_CPU_SRC
+  "${CMAKE_CURRENT_SOURCE_DIR}/predictor.cc"
+  "${CMAKE_CURRENT_SOURCE_DIR}/predictor_utils.cc"
+)
+set(Caffe2_PREDICTOR_CPU_TEST_SRC
+  "${CMAKE_CURRENT_SOURCE_DIR}/predictor_test.cc")
+
+# Common files that are always going to be included.
+list(APPEND Caffe2_CPU_SRCS ${Caffe2_PREDICTOR_CPU_SRC}) +list(APPEND Caffe2_CPU_TEST_SRCS ${Caffe2_PREDICTOR_CPU_TEST_SRC}) + +set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) +set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE) diff --git a/caffe2/core/predictor.cc b/caffe2/predictor/predictor.cc similarity index 99% rename from caffe2/core/predictor.cc rename to caffe2/predictor/predictor.cc index 2aaa7a2dac3a30..8c3001571d2f3c 100644 --- a/caffe2/core/predictor.cc +++ b/caffe2/predictor/predictor.cc @@ -1,4 +1,4 @@ -#include "caffe2/core/predictor.h" +#include "caffe2/predictor/predictor.h" #ifdef CAFFE2_OPTIMIZER #include "caffe2/opt/optimizer.h" #endif diff --git a/caffe2/core/predictor.h b/caffe2/predictor/predictor.h similarity index 97% rename from caffe2/core/predictor.h rename to caffe2/predictor/predictor.h index b56401a35da5c3..a3f05d7aacac89 100644 --- a/caffe2/core/predictor.h +++ b/caffe2/predictor/predictor.h @@ -2,8 +2,8 @@ #include #include "caffe2/core/net.h" -#include "caffe2/core/predictor_config.h" #include "caffe2/core/tensor.h" +#include "caffe2/predictor/predictor_config.h" #include "caffe2/proto/metanet.pb.h" #include "caffe2/proto/predictor_consts.pb.h" diff --git a/caffe2/core/predictor_config.h b/caffe2/predictor/predictor_config.h similarity index 100% rename from caffe2/core/predictor_config.h rename to caffe2/predictor/predictor_config.h diff --git a/caffe2/core/predictor_test.cc b/caffe2/predictor/predictor_test.cc similarity index 99% rename from caffe2/core/predictor_test.cc rename to caffe2/predictor/predictor_test.cc index a37dbbb9e8d39e..31a102aea8712b 100644 --- a/caffe2/core/predictor_test.cc +++ b/caffe2/predictor/predictor_test.cc @@ -1,7 +1,7 @@ #include "caffe2/core/context.h" #include "caffe2/core/operator.h" -#include "caffe2/core/predictor.h" #include "caffe2/core/tensor.h" +#include "caffe2/predictor/predictor.h" #include "caffe2/utils/math.h" #include diff --git a/caffe2/core/predictor_utils.cc b/caffe2/predictor/predictor_utils.cc similarity index 98% rename from caffe2/core/predictor_utils.cc rename to caffe2/predictor/predictor_utils.cc index dea0388fc12528..cc37eec85fbaa1 100644 --- a/caffe2/core/predictor_utils.cc +++ b/caffe2/predictor/predictor_utils.cc @@ -1,4 +1,4 @@ -#include "caffe2/core/predictor_utils.h" +#include "caffe2/predictor/predictor_utils.h" #include "caffe2/core/blob.h" #include "caffe2/core/logging.h" diff --git a/caffe2/core/predictor_utils.h b/caffe2/predictor/predictor_utils.h similarity index 100% rename from caffe2/core/predictor_utils.h rename to caffe2/predictor/predictor_utils.h diff --git a/caffe2/python/operator_test/order_switch_test.py b/caffe2/python/operator_test/order_switch_test.py new file mode 100644 index 00000000000000..71ba64e40f3ffb --- /dev/null +++ b/caffe2/python/operator_test/order_switch_test.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np +import caffe2.python.hypothesis_test_util as hu +from caffe2.python import core +from hypothesis import given +import hypothesis.strategies as st + + +class OrderSwitchOpsTest(hu.HypothesisTestCase): + @given( + n=st.integers(1, 5), + c=st.integers(1, 5), + h=st.integers(1, 5), + w=st.integers(1, 5), + **hu.gcs) + def test_nchw2nhwc(self, n, c, h, w, gc, dc): + X = np.random.randn(n, c, h, w).astype(np.float32) + + op = core.CreateOperator("NCHW2NHWC", ["X"], ["Y"], + device_option=gc) + + def 
nchw2nhwc_ref(X): + X_reshaped = X.transpose((0, 2, 3, 1)) + return (X_reshaped,) + + self.assertReferenceChecks(gc, op, [X], nchw2nhwc_ref) + self.assertGradientChecks(gc, op, [X], 0, [0]) + self.assertDeviceChecks(dc, op, [X], [0]) + + @given( + n=st.integers(1, 5), + c=st.integers(1, 5), + h=st.integers(1, 5), + w=st.integers(1, 5), + **hu.gcs) + def test_nhwc2nchw(self, n, c, h, w, gc, dc): + X = np.random.randn(n, h, w, c).astype(np.float32) + + op = core.CreateOperator("NHWC2NCHW", ["X"], ["Y"], + device_option=gc) + + def nhwc2nchw_ref(X): + X_reshaped = X.transpose((0, 3, 1, 2)) + return (X_reshaped,) + + self.assertReferenceChecks(gc, op, [X], nhwc2nchw_ref) + self.assertGradientChecks(gc, op, [X], 0, [0]) + self.assertDeviceChecks(dc, op, [X], [0]) diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 04df247d821daf..9256896bd8d138 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -12,7 +12,6 @@ #include "caffe2/core/db.h" #include "caffe2/core/numa.h" #include "caffe2/core/operator.h" -#include "caffe2/core/predictor.h" #include "caffe2/core/stats.h" #include "caffe2/core/transform.h" #include "caffe2/mkl/mkl_utils.h" @@ -28,6 +27,7 @@ #include "caffe2/opt/optimize_ideep.h" #include "caffe2/opt/passes.h" #include "caffe2/opt/sink.h" +#include "caffe2/predictor/predictor.h" #include "caffe2/utils/cpuid.h" #include "caffe2/utils/string_utils.h" diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 05909a692b2a5b..c3c85797b4cd82 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -224,6 +224,7 @@ view of a storage and defines numeric operations on it. .. automethod:: expand_as .. automethod:: exponential_ .. automethod:: fill_ + .. automethod:: flatten .. automethod:: flip .. automethod:: float .. automethod:: floor diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 3ee7d6e7abe68c..c1e914c03c74e7 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -259,6 +259,7 @@ Other Operations .. autofunction:: diagflat .. autofunction:: diagonal .. autofunction:: einsum +.. autofunction:: flatten .. autofunction:: flip .. autofunction:: histc .. 
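
Note on the docs additions above: torch.flatten and Tensor.flatten are now listed in torch.rst and tensors.rst. A short usage sketch of the documented API (illustrative only):

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

print(torch.flatten(x).shape)        # torch.Size([24]): all dims collapsed
print(x.flatten(start_dim=1).shape)  # torch.Size([2, 12]): keep the batch dim
```
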
autofunction:: meshgrid diff --git a/test/expect/TestJit.test_alexnet.expect b/test/expect/TestJit.test_alexnet.expect index 9a1105c8b8b176..09d274f58be622 100644 --- a/test/expect/TestJit.test_alexnet.expect +++ b/test/expect/TestJit.test_alexnet.expect @@ -15,38 +15,51 @@ graph(%0 : Double(1, 3, 224, 224) %14 : Double(4096) %15 : Double(1000, 4096) %16 : Double(1000)) { - %17 : Double(1, 64, 55, 55) = aten::_convolution[stride=[4, 4], padding=[2, 2], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%0, %1, %2), scope: AlexNet/Sequential[features]/Conv2d[0] - %18 : Double(1, 64, 55, 55) = aten::threshold[threshold={0}, value={0}](%17), scope: AlexNet/Sequential[features]/ReLU[1] - %19 : Double(1, 64, 27, 27), %20 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2] - %21 : Double(1, 192, 27, 27) = aten::_convolution[stride=[1, 1], padding=[2, 2], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%19, %3, %4), scope: AlexNet/Sequential[features]/Conv2d[3] - %22 : Double(1, 192, 27, 27) = aten::threshold[threshold={0}, value={0}](%21), scope: AlexNet/Sequential[features]/ReLU[4] - %23 : Double(1, 192, 13, 13), %24 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%22), scope: AlexNet/Sequential[features]/MaxPool2d[5] - %25 : Double(1, 384, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%23, %5, %6), scope: AlexNet/Sequential[features]/Conv2d[6] - %26 : Double(1, 384, 13, 13) = aten::threshold[threshold={0}, value={0}](%25), scope: AlexNet/Sequential[features]/ReLU[7] - %27 : Double(1, 256, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%26, %7, %8), scope: AlexNet/Sequential[features]/Conv2d[8] - %28 : Double(1, 256, 13, 13) = aten::threshold[threshold={0}, value={0}](%27), scope: AlexNet/Sequential[features]/ReLU[9] - %29 : Double(1, 256, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%28, %9, %10), scope: AlexNet/Sequential[features]/Conv2d[10] - %30 : Double(1, 256, 13, 13) = aten::threshold[threshold={0}, value={0}](%29), scope: AlexNet/Sequential[features]/ReLU[11] - %31 : Double(1, 256, 6, 6), %32 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%30), scope: AlexNet/Sequential[features]/MaxPool2d[12] - %33 : int = prim::Constant[value=0](), scope: AlexNet - %34 : int = aten::size(%31, %33), scope: AlexNet - %35 : Long() = prim::NumToTensor(%34), scope: AlexNet - %36 : int = prim::TensorToNum(%35), scope: AlexNet - %37 : int = prim::Constant[value=9216](), scope: AlexNet - %38 : int[] = prim::ListConstruct(%36, %37), scope: AlexNet - %39 : Double(1, 9216) = aten::view(%31, %38), scope: AlexNet - %40 : Double(1, 9216) = ^Dropout(0.5, True, False)(%39), scope: AlexNet/Sequential[classifier]/Dropout[0] - %41 : Double(9216!, 4096!) 
= aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1] - %42 : Double(1, 4096) = aten::expand[size=[1, 4096], implicit=1](%12), scope: AlexNet/Sequential[classifier]/Linear[1] - %43 : Double(1, 4096) = aten::addmm[beta={1}, alpha={1}](%42, %40, %41), scope: AlexNet/Sequential[classifier]/Linear[1] - %44 : Double(1, 4096) = aten::threshold[threshold={0}, value={0}](%43), scope: AlexNet/Sequential[classifier]/ReLU[2] - %45 : Double(1, 4096) = ^Dropout(0.5, True, False)(%44), scope: AlexNet/Sequential[classifier]/Dropout[3] - %46 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4] - %47 : Double(1, 4096) = aten::expand[size=[1, 4096], implicit=1](%14), scope: AlexNet/Sequential[classifier]/Linear[4] - %48 : Double(1, 4096) = aten::addmm[beta={1}, alpha={1}](%47, %45, %46), scope: AlexNet/Sequential[classifier]/Linear[4] - %49 : Double(1, 4096) = aten::threshold[threshold={0}, value={0}](%48), scope: AlexNet/Sequential[classifier]/ReLU[5] - %50 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6] - %51 : Double(1, 1000) = aten::expand[size=[1, 1000], implicit=1](%16), scope: AlexNet/Sequential[classifier]/Linear[6] - %52 : Double(1, 1000) = aten::addmm[beta={1}, alpha={1}](%51, %49, %50), scope: AlexNet/Sequential[classifier]/Linear[6] - return (%52); + %17 : int = prim::Constant[value=4](), scope: AlexNet/Sequential[features]/Conv2d[0] + %18 : int[] = prim::ListConstruct(%17, %17), scope: AlexNet/Sequential[features]/Conv2d[0] + %19 : int = prim::Constant[value=2](), scope: AlexNet/Sequential[features]/Conv2d[0] + %20 : int[] = prim::ListConstruct(%19, %19), scope: AlexNet/Sequential[features]/Conv2d[0] + %21 : int = prim::Constant[value=1](), scope: AlexNet/Sequential[features]/Conv2d[0] + %22 : int[] = prim::ListConstruct(%21, %21), scope: AlexNet/Sequential[features]/Conv2d[0] + %23 : int = prim::Constant[value=0](), scope: AlexNet/Sequential[features]/Conv2d[0] + %24 : int[] = prim::ListConstruct(%23, %23), scope: AlexNet/Sequential[features]/Conv2d[0] + %25 : Double(1, 64, 55, 55) = aten::_convolution(%0, %1, %2, %18, %20, %22, %23, %24, %21, %23, %23, %21), scope: AlexNet/Sequential[features]/Conv2d[0] + %26 : Double(1, 64, 55, 55) = aten::threshold(%25, %23, %23), scope: AlexNet/Sequential[features]/ReLU[1] + %27 : int = prim::Constant[value=3](), scope: AlexNet/Sequential[features]/MaxPool2d[2] + %28 : int[] = prim::ListConstruct(%27, %27), scope: AlexNet/Sequential[features]/MaxPool2d[2] + %29 : Double(1, 64, 27, 27), %30 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices(%26, %28, %20, %24, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[2] + %31 : Double(1, 192, 27, 27) = aten::_convolution(%29, %3, %4, %22, %20, %22, %23, %24, %21, %23, %23, %21), scope: AlexNet/Sequential[features]/Conv2d[3] + %32 : Double(1, 192, 27, 27) = aten::threshold(%31, %23, %23), scope: AlexNet/Sequential[features]/ReLU[4] + %33 : Double(1, 192, 13, 13), %34 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices(%32, %28, %20, %24, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[5] + %35 : Double(1, 384, 13, 13) = aten::_convolution(%33, %5, %6, %22, %22, %22, %23, %24, %21, %23, %23, %21), scope: AlexNet/Sequential[features]/Conv2d[6] + %36 : Double(1, 384, 13, 13) = aten::threshold(%35, %23, %23), scope: AlexNet/Sequential[features]/ReLU[7] + %37 : Double(1, 256, 13, 13) = aten::_convolution(%36, %7, %8, %22, %22, %22, %23, %24, %21, %23, %23, %21), scope: AlexNet/Sequential[features]/Conv2d[8] + %38 : 
Double(1, 256, 13, 13) = aten::threshold(%37, %23, %23), scope: AlexNet/Sequential[features]/ReLU[9] + %39 : Double(1, 256, 13, 13) = aten::_convolution(%38, %9, %10, %22, %22, %22, %23, %24, %21, %23, %23, %21), scope: AlexNet/Sequential[features]/Conv2d[10] + %40 : Double(1, 256, 13, 13) = aten::threshold(%39, %23, %23), scope: AlexNet/Sequential[features]/ReLU[11] + %41 : Double(1, 256, 6, 6), %42 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices(%40, %28, %20, %24, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[12] + %43 : int = aten::size(%41, %23), scope: AlexNet + %44 : Long() = prim::NumToTensor(%43), scope: AlexNet + %45 : int = prim::TensorToNum(%44), scope: AlexNet + %46 : int = prim::Constant[value=9216](), scope: AlexNet + %47 : int[] = prim::ListConstruct(%45, %46), scope: AlexNet + %48 : Double(1, 9216) = aten::view(%41, %47), scope: AlexNet + %49 : Double(1, 9216) = ^Dropout(0.5, True, False)(%48), scope: AlexNet/Sequential[classifier]/Dropout[0] + %50 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1] + %51 : int = prim::Constant[value=4096](), scope: AlexNet/Sequential[classifier]/Linear[1] + %52 : int[] = prim::ListConstruct(%21, %51), scope: AlexNet/Sequential[classifier]/Linear[1] + %53 : Double(1, 4096) = aten::expand(%12, %52, %21), scope: AlexNet/Sequential[classifier]/Linear[1] + %54 : Double(1, 4096) = aten::addmm(%53, %49, %50, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1] + %55 : Double(1, 4096) = aten::threshold(%54, %23, %23), scope: AlexNet/Sequential[classifier]/ReLU[2] + %56 : Double(1, 4096) = ^Dropout(0.5, True, False)(%55), scope: AlexNet/Sequential[classifier]/Dropout[3] + %57 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4] + %58 : Double(1, 4096) = aten::expand(%14, %52, %21), scope: AlexNet/Sequential[classifier]/Linear[4] + %59 : Double(1, 4096) = aten::addmm(%58, %56, %57, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4] + %60 : Double(1, 4096) = aten::threshold(%59, %23, %23), scope: AlexNet/Sequential[classifier]/ReLU[5] + %61 : Double(4096!, 1000!) 
= aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6] + %62 : int = prim::Constant[value=1000](), scope: AlexNet/Sequential[classifier]/Linear[6] + %63 : int[] = prim::ListConstruct(%21, %62), scope: AlexNet/Sequential[classifier]/Linear[6] + %64 : Double(1, 1000) = aten::expand(%16, %63, %21), scope: AlexNet/Sequential[classifier]/Linear[6] + %65 : Double(1, 1000) = aten::addmm(%64, %60, %61, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6] + return (%65); } diff --git a/test/expect/TestJit.test_batchnorm.expect b/test/expect/TestJit.test_batchnorm.expect index 4fa8a72a43ae7f..c61390578d45b8 100644 --- a/test/expect/TestJit.test_batchnorm.expect +++ b/test/expect/TestJit.test_batchnorm.expect @@ -4,6 +4,10 @@ graph(%0 : Double(2, 2, 2, 2) %3 : Double(2) %4 : Double(2) %5 : Long()) { - %6 : Double(2, 2, 2, 2) = aten::batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d - return (%6); + %6 : int = prim::Constant[value=1](), scope: BatchNorm2d + %7 : float = prim::Constant[value=0.1](), scope: BatchNorm2d + %8 : float = prim::Constant[value=1e-05](), scope: BatchNorm2d + %9 : int = prim::Constant[value=1](), scope: BatchNorm2d + %10 : Double(2, 2, 2, 2) = aten::batch_norm(%0, %1, %2, %3, %4, %6, %7, %8, %9), scope: BatchNorm2d + return (%10); } diff --git a/test/expect/TestJit.test_concat_fusion.expect b/test/expect/TestJit.test_concat_fusion.expect index c1b45b172745ba..027c2de33e5926 100644 --- a/test/expect/TestJit.test_concat_fusion.expect +++ b/test/expect/TestJit.test_concat_fusion.expect @@ -3,10 +3,12 @@ graph(%0 : Float(3, 20) %2 : Float(6, 20) = prim::FusionGroup_0[device=0](%0, %1) return (%2); } -with prim::FusionGroup_0 = graph(%3 : Float(3, 20) - %4 : Float(3, 20)) { - %6 : Float(3, 20) = aten::add[alpha={1}](%3, %4) - %5 : Float(3, 20) = aten::mul(%3, %4) - %2 : Float(6, 20) = aten::cat[dim=0](%6, %5) - return (%2); +with prim::FusionGroup_0 = graph(%4 : Float(3, 20) + %5 : Float(3, 20)) { + %7 : int = prim::Constant[value=1]() + %8 : Float(3, 20) = aten::add(%4, %5, %7) + %6 : Float(3, 20) = aten::mul(%4, %5) + %2 : int = prim::Constant[value=0]() + %3 : Float(6, 20) = aten::cat(%8, %6, %2) + return (%3); } diff --git a/test/expect/TestJit.test_conv.expect b/test/expect/TestJit.test_conv.expect index 584f807a8ca071..fcb53bad1425dc 100644 --- a/test/expect/TestJit.test_conv.expect +++ b/test/expect/TestJit.test_conv.expect @@ -1,6 +1,23 @@ graph(%0 : Double(20, 16, 50, 40) %1 : Double(13, 16, 3, 3)) { %2 : Dynamic = prim::Undefined(), scope: Conv2d - %3 : Double(20, 13, 48, 38) = aten::_convolution[stride=[1, 1], padding=[0, 0], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%0, %1, %2), scope: Conv2d - return (%3); + %3 : int = prim::Constant[value=1](), scope: Conv2d + %4 : int = prim::Constant[value=1](), scope: Conv2d + %5 : int[] = prim::ListConstruct(%3, %4), scope: Conv2d + %6 : int = prim::Constant[value=0](), scope: Conv2d + %7 : int = prim::Constant[value=0](), scope: Conv2d + %8 : int[] = prim::ListConstruct(%6, %7), scope: Conv2d + %9 : int = prim::Constant[value=1](), scope: Conv2d + %10 : int = prim::Constant[value=1](), scope: Conv2d + %11 : int[] = prim::ListConstruct(%9, %10), scope: Conv2d + %12 : int = prim::Constant[value=0](), scope: Conv2d + %13 : int = prim::Constant[value=0](), scope: Conv2d + %14 : int = prim::Constant[value=0](), scope: Conv2d + %15 : int[] = prim::ListConstruct(%13, %14), scope: Conv2d + %16 : int = 
prim::Constant[value=1](), scope: Conv2d + %17 : int = prim::Constant[value=0](), scope: Conv2d + %18 : int = prim::Constant[value=0](), scope: Conv2d + %19 : int = prim::Constant[value=1](), scope: Conv2d + %20 : Double(20, 13, 48, 38) = aten::_convolution(%0, %1, %2, %5, %8, %11, %12, %15, %16, %17, %18, %19), scope: Conv2d + return (%20); } diff --git a/test/expect/TestJit.test_cpp.expect b/test/expect/TestJit.test_cpp.expect index bfe49e45cfb618..f1f3a6a9c39012 100644 --- a/test/expect/TestJit.test_cpp.expect +++ b/test/expect/TestJit.test_cpp.expect @@ -2,47 +2,60 @@ testBlocks graph(%a : Dynamic %b : Dynamic %c : Dynamic) { - %2 : Dynamic = aten::add[alpha={1}](%a, %b) - %4 : Dynamic = prim::If(%c) + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::add(%a, %b, %2) + %5 : Dynamic = prim::If(%c) block0() { - %5 : Dynamic = aten::add[alpha={1}](%2, %2) - -> (%5) + %6 : int = prim::Constant[value=1]() + %7 : Dynamic = aten::add(%3, %3, %6) + -> (%7) } block1() { - %6 : Dynamic = aten::add[alpha={1}](%b, %2) - %7 : Dynamic = aten::add[alpha={1}](%6, %2) - -> (%7) + %8 : int = prim::Constant[value=1]() + %9 : Dynamic = aten::add(%b, %3, %8) + %10 : int = prim::Constant[value=1]() + %11 : Dynamic = aten::add(%9, %3, %10) + -> (%11) } - %8 : Dynamic = aten::add[alpha={1}](%4, %2) - return (%8); + %12 : int = prim::Constant[value=1]() + %13 : Dynamic = aten::add(%5, %3, %12) + return (%13); } graph(%a : Dynamic %b : Dynamic %c : Dynamic) { - %2 : Dynamic = aten::add[alpha={1}](%a, %b) - %4 : Dynamic = prim::If(%c) + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::add(%a, %b, %2) + %5 : Dynamic = prim::If(%c) block0() { - %6 : Dynamic = aten::add[alpha={1}](%b, %2) - %7 : Dynamic = aten::add[alpha={1}](%6, %2) - -> (%7) + %8 : int = prim::Constant[value=1]() + %9 : Dynamic = aten::add(%b, %3, %8) + %10 : int = prim::Constant[value=1]() + %11 : Dynamic = aten::add(%9, %3, %10) + -> (%11) } - %8 : Dynamic = aten::add[alpha={1}](%4, %2) - return (%8); + %12 : int = prim::Constant[value=1]() + %13 : Dynamic = aten::add(%5, %3, %12) + return (%13); } graph(%a : Dynamic %b : Dynamic %c : Dynamic) { - %3 : Dynamic = aten::add[alpha={1}](%a, %b) - %4 : Dynamic = prim::If(%c) + %3 : int = prim::Constant[value=1]() + %4 : Dynamic = aten::add(%a, %b, %3) + %5 : Dynamic = prim::If(%c) block0() { - %5 : Dynamic = aten::add[alpha={1}](%b, %3) - %6 : Dynamic = aten::add[alpha={1}](%5, %3) - -> (%6) + %6 : int = prim::Constant[value=1]() + %7 : Dynamic = aten::add(%b, %4, %6) + %8 : int = prim::Constant[value=1]() + %9 : Dynamic = aten::add(%7, %4, %8) + -> (%9) } - %7 : Dynamic = aten::add[alpha={1}](%4, %3) - return (%7); + %10 : int = prim::Constant[value=1]() + %11 : Dynamic = aten::add(%5, %4, %10) + return (%11); } testCreateAutodiffSubgraphs @@ -51,28 +64,32 @@ graph(%0 : Dynamic %2 : Dynamic %3 : Dynamic %4 : Dynamic) { - %21 : Dynamic, %22 : Dynamic = prim::GraphExecutor_0(%0, %3, %1, %4, %2) - return (%22, %21); + %25 : Dynamic, %26 : Dynamic = prim::GraphExecutor_0(%0, %3, %1, %4, %2) + return (%26, %25); } with prim::GraphExecutor_0 = graph(%1 : Dynamic %2 : Dynamic %4 : Dynamic %5 : Dynamic - %16 : Dynamic) { + %19 : Dynamic) { %0 : Dynamic = aten::mm(%1, %2) %3 : Dynamic = aten::mm(%4, %5) - %6 : Dynamic = aten::add[alpha={1}](%0, %3) - %7 : Dynamic, %8 : Dynamic, %9 : Dynamic, %10 : Dynamic = aten::chunk[chunks=4, dim=1](%6) - %11 : Dynamic = aten::sigmoid(%7) - %12 : Dynamic = aten::sigmoid(%10) - %13 : Dynamic = aten::tanh(%9) - %14 : Dynamic = aten::sigmoid(%8) - %15 : 
Dynamic = aten::mul(%14, %16) - %17 : Dynamic = aten::mul(%11, %13) - %18 : Dynamic = aten::add[alpha={1}](%15, %17) - %19 : Dynamic = aten::tanh(%18) - %20 : Dynamic = aten::mul(%12, %19) - return (%18, %20); + %6 : int = prim::Constant[value=1]() + %7 : Dynamic = aten::add(%0, %3, %6) + %8 : int = prim::Constant[value=4]() + %9 : int = prim::Constant[value=1]() + %10 : Dynamic, %11 : Dynamic, %12 : Dynamic, %13 : Dynamic = aten::chunk(%7, %8, %9) + %14 : Dynamic = aten::sigmoid(%10) + %15 : Dynamic = aten::sigmoid(%13) + %16 : Dynamic = aten::tanh(%12) + %17 : Dynamic = aten::sigmoid(%11) + %18 : Dynamic = aten::mul(%17, %19) + %20 : Dynamic = aten::mul(%14, %16) + %21 : int = prim::Constant[value=1]() + %22 : Dynamic = aten::add(%18, %20, %21) + %23 : Dynamic = aten::tanh(%22) + %24 : Dynamic = aten::mul(%15, %23) + return (%22, %24); } testDifferentiate @@ -80,66 +97,75 @@ graph(%0 : Float(2, 3, 4) %1 : Float(2, 3, 4)) { %2 : Float(2, 3, 4) = aten::mul(%0, %1) %3 : Float(2, 3, 4) = aten::mul(%2, %0) - %4 : Float(2, 3, 4) = aten::add[alpha={1}](%3, %1) - return (%4, %2); + %4 : int = prim::Constant[value=1]() + %5 : Float(2, 3, 4) = aten::add(%3, %1, %4) + return (%5, %2); } graph(%0 : Float(2, 3, 4) %1 : Float(2, 3, 4) %2 : Float(2, 3, 4) %3 : Float(2, 3, 4) %4 : Float(2, 3, 4)) { - %5 : Float(2, 3, 4), %6 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%0) + %5 : int = prim::Constant[value=1]() + %6 : Float(2, 3, 4), %7 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%0) block0() { - -> (%0, %0) + %8 : Float(2, 3, 4) = aten::mul(%0, %5) + -> (%0, %8) } - %7 : Float(2, 3, 4), %8 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%5) + %9 : Float(2, 3, 4), %10 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%6) block0() { - %9 : Float(2, 3, 4) = aten::mul(%5, %2) - %10 : Float(2, 3, 4) = aten::mul(%5, %4) - -> (%9, %10) + %11 : Float(2, 3, 4) = aten::mul(%6, %2) + %12 : Float(2, 3, 4) = aten::mul(%6, %4) + -> (%11, %12) } - %11 : Dynamic = prim::AutogradAdd(%1, %7) - %12 : Float(2, 3, 4), %13 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%11) + %13 : Dynamic = prim::AutogradAdd(%1, %9) + %14 : Float(2, 3, 4), %15 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%13) block0() { - %14 : Float(2, 3, 4) = aten::mul(%11, %3) - %15 : Float(2, 3, 4) = aten::mul(%11, %2) - -> (%14, %15) + %16 : Float(2, 3, 4) = aten::mul(%13, %3) + %17 : Float(2, 3, 4) = aten::mul(%13, %2) + -> (%16, %17) } - %16 : Dynamic = prim::AutogradAdd(%8, %12) - %17 : Dynamic = prim::AutogradAdd(%6, %13) - return (%16, %17); + %18 : Dynamic = prim::AutogradAdd(%10, %14) + %19 : Dynamic = prim::AutogradAdd(%7, %15) + return (%18, %19); } testDifferentiateWithRequiresGrad graph(%0 : Float(2, 3, 4) %1 : Float(2, 3, 4)) { %2 : Float(2, 3, 4) = aten::mul(%1, %1) - %3 : Float(2, 3, 4) = aten::add[alpha={1}](%2, %1) - %4 : Float(2, 3, 4) = aten::add[alpha={1}](%3, %0) - %5 : Float(2, 3, 4) = aten::mul(%4, %0) - %6 : Float(2, 3, 4) = aten::add[alpha={1}](%5, %1) - return (%3, %6, %4); + %3 : int = prim::Constant[value=1]() + %4 : Float(2, 3, 4) = aten::add(%2, %1, %3) + %5 : int = prim::Constant[value=1]() + %6 : Float(2, 3, 4) = aten::add(%4, %0, %5) + %7 : Float(2, 3, 4) = aten::mul(%6, %0) + %8 : int = prim::Constant[value=1]() + %9 : Float(2, 3, 4) = aten::add(%7, %1, %8) + return (%4, %9, %6); } graph(%0 : Float(2, 3, 4) %1 : Float(2, 3, 4) %2 : Float(2, 3, 4) %3 : Float(2, 3, 4)) { - %4 : Float(2, 3, 4), %5 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%0) + %4 : int = prim::Constant[value=1]() + %5 : Float(2, 3, 4), %6 : 
Float(2, 3, 4) = prim::GradOf[name=aten::add](%0) block0() { - -> (%0, %0) + %7 : Float(2, 3, 4) = aten::mul(%0, %4) + -> (%0, %7) } - %6 : Float(2, 3, 4), %7 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%4) + %8 : Float(2, 3, 4), %9 : Float(2, 3, 4) = prim::GradOf[name=aten::mul](%5) block0() { - %8 : Float(2, 3, 4) = aten::mul(%4, %2) - %9 : Float(2, 3, 4) = aten::mul(%4, %3) - -> (%8, %9) + %10 : Float(2, 3, 4) = aten::mul(%5, %2) + %11 : Float(2, 3, 4) = aten::mul(%5, %3) + -> (%10, %11) } - %10 : Dynamic = prim::AutogradAdd(%1, %6) - %11 : Float(2, 3, 4), %12 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%10) + %12 : Dynamic = prim::AutogradAdd(%1, %8) + %13 : Float(2, 3, 4), %14 : Float(2, 3, 4) = prim::GradOf[name=aten::add](%12) block0() { - -> (%10, %10) + %15 : Float(2, 3, 4) = aten::mul(%12, %4) + -> (%12, %15) } - %13 : Dynamic = prim::AutogradAdd(%7, %12) - return (%13); + %16 : Dynamic = prim::AutogradAdd(%9, %14) + return (%16); } diff --git a/test/expect/TestJit.test_cse.expect b/test/expect/TestJit.test_cse.expect index b3d1a81a9929b8..46d9a4c6a17e0c 100644 --- a/test/expect/TestJit.test_cse.expect +++ b/test/expect/TestJit.test_cse.expect @@ -1,10 +1,11 @@ graph(%0 : Double(2) %1 : Double(2)) { - %2 : Double(2) = aten::add[alpha={1}](%0, %1) - %3 : Double(2) = aten::mul(%2, %2) - %4 : Double(2) = aten::mul(%3, %2) - %5 : Double(2) = aten::tanh(%4) - %6 : Double(2) = aten::add[alpha={1}](%5, %5) - %7 : Double(2) = aten::add[alpha={1}](%4, %6) - return (%7); + %2 : int = prim::Constant[value=1]() + %3 : Double(2) = aten::add(%0, %1, %2) + %4 : Double(2) = aten::mul(%3, %3) + %5 : Double(2) = aten::mul(%4, %3) + %6 : Double(2) = aten::tanh(%5) + %7 : Double(2) = aten::add(%6, %6, %2) + %8 : Double(2) = aten::add(%5, %7, %2) + return (%8); } diff --git a/test/expect/TestJit.test_decompose_addmm.expect b/test/expect/TestJit.test_decompose_addmm.expect index 925362f4f6a4ae..65a3e416d2b1e9 100644 --- a/test/expect/TestJit.test_decompose_addmm.expect +++ b/test/expect/TestJit.test_decompose_addmm.expect @@ -3,16 +3,23 @@ graph(%mat : Dynamic %mat2 : Dynamic %alpha : Dynamic %beta : Dynamic) { - %5 : Dynamic = aten::mm(%mat1, %mat2) - %6 : Dynamic = aten::add[alpha={1}](%mat, %5) - %7 : Dynamic = aten::mm(%mat1, %mat2) - %8 : Dynamic = aten::add[alpha={1}](%mat, %7) - %c : Dynamic = aten::addmm[beta={2}, alpha={4.2}](%mat, %mat1, %mat2) - %10 : int = prim::TensorToNum(%alpha) - %11 : int = prim::TensorToNum(%beta) - %d : Dynamic = aten::addmm(%mat, %mat1, %mat2, %11, %10) - %13 : Dynamic = aten::add[alpha={1}](%6, %8) - %14 : Dynamic = aten::add[alpha={1}](%13, %c) - %15 : Dynamic = aten::add[alpha={1}](%14, %d) - return (%15); + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() + %a : Dynamic = aten::addmm(%mat, %mat1, %mat2, %5, %6) + %8 : float = prim::Constant[value=1]() + %9 : float = prim::Constant[value=1]() + %b : Dynamic = aten::addmm(%mat, %mat1, %mat2, %9, %8) + %11 : float = prim::Constant[value=4.2]() + %12 : float = prim::Constant[value=2]() + %c : Dynamic = aten::addmm(%mat, %mat1, %mat2, %12, %11) + %14 : int = prim::TensorToNum(%alpha) + %15 : int = prim::TensorToNum(%beta) + %d : Dynamic = aten::addmm(%mat, %mat1, %mat2, %15, %14) + %17 : int = prim::Constant[value=1]() + %18 : Dynamic = aten::add(%a, %b, %17) + %19 : int = prim::Constant[value=1]() + %20 : Dynamic = aten::add(%18, %c, %19) + %21 : int = prim::Constant[value=1]() + %22 : Dynamic = aten::add(%20, %d, %21) + return (%22); } diff --git 
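
Note on the expect-file updates above: they all follow the same pattern. Scalar arguments that used to be stored as node attributes (alpha, beta, dim, chunks, threshold values) are now traced as explicit prim::Constant inputs. A quick way to inspect the new form, assuming the torch.jit.trace(fn, example_inputs) entry point:

```python
import torch

def f(x, y):
    return (x + y) * x

# Tracing records add's implicit alpha=1 as its own constant node
# feeding aten::add, instead of an [alpha={1}] attribute.
traced = torch.jit.trace(f, (torch.rand(2), torch.rand(2)))
print(traced.graph)
# e.g.  %2 : int = prim::Constant[value=1]()
#       %3 : ... = aten::add(%0, %1, %2)
```
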
a/test/expect/TestJit.test_fuse_last_device.expect b/test/expect/TestJit.test_fuse_last_device.expect index 276fadc61fd7df..e5613bfa975ffd 100644 --- a/test/expect/TestJit.test_fuse_last_device.expect +++ b/test/expect/TestJit.test_fuse_last_device.expect @@ -3,12 +3,14 @@ graph(%0 : Float(1) %2 : Float(1) = prim::FusionGroup_0[device=1](%0, %1) return (%2); } -with prim::FusionGroup_0 = graph(%6 : Float(1) - %9 : Float(1)) { - %10 : Float(1) = aten::add[alpha={1}](%6, %9) - %8 : Float(1) = aten::mul(%6, %10) - %5 : Float(1) = aten::add[other={1}, alpha={1}](%8) - %3 : Float(1) = aten::tanh(%5) +with prim::FusionGroup_0 = graph(%7 : Float(1) + %10 : Float(1)) { + %11 : int = prim::Constant[value=1]() + %12 : Float(1) = aten::add(%7, %10, %11) + %9 : Float(1) = aten::mul(%7, %12) + %5 : int = prim::Constant[value=1]() + %6 : Float(1) = aten::add(%9, %5, %5) + %3 : Float(1) = aten::tanh(%6) %1 : Float(1) = aten::sigmoid(%3) return (%1); } diff --git a/test/expect/TestJit.test_fusion_distribute.expect b/test/expect/TestJit.test_fusion_distribute.expect index 4465074e556585..380a92c8a112d0 100644 --- a/test/expect/TestJit.test_fusion_distribute.expect +++ b/test/expect/TestJit.test_fusion_distribute.expect @@ -1,16 +1,20 @@ graph(%0 : Float(4, 4) %1 : Float(4, 4)) { - %2 : Float(4!, 2), %3 : Float(4!, 2) = aten::chunk[chunks=2, dim=1](%0) - %4 : Float(4!, 2), %5 : Float(4!, 2) = aten::chunk[chunks=2, dim=1](%1) - %6 : Float(4, 2) = prim::FusionGroup_0[device=0](%2, %4, %3, %5) - return (%6); + %2 : int = prim::Constant[value=1]() + %3 : int = prim::Constant[value=2]() + %4 : Float(4!, 2), %5 : Float(4!, 2) = aten::chunk(%0, %3, %2) + %6 : Float(4!, 2), %7 : Float(4!, 2) = aten::chunk(%1, %3, %2) + %8 : Float(4, 2) = prim::FusionGroup_0[device=0](%4, %6, %5, %7) + return (%8); } with prim::FusionGroup_0 = graph(%3 : Float(4!, 2) %4 : Float(4!, 2) - %6 : Float(4!, 2) - %7 : Float(4!, 2)) { - %8 : Float(4, 2) = aten::add[alpha={1}](%6, %7) - %5 : Float(4, 2) = aten::add[alpha={1}](%3, %4) - %2 : Float(4, 2) = aten::mul(%5, %8) + %7 : Float(4!, 2) + %8 : Float(4!, 2)) { + %9 : int = prim::Constant[value=1]() + %10 : Float(4, 2) = aten::add(%7, %8, %9) + %5 : int = prim::Constant[value=1]() + %6 : Float(4, 2) = aten::add(%3, %4, %5) + %2 : Float(4, 2) = aten::mul(%6, %10) return (%2); } diff --git a/test/expect/TestJit.test_inplace_transplant.expect b/test/expect/TestJit.test_inplace_transplant.expect index e31e8c783b62b1..c9a84219a5ed6d 100644 --- a/test/expect/TestJit.test_inplace_transplant.expect +++ b/test/expect/TestJit.test_inplace_transplant.expect @@ -1,6 +1,10 @@ graph(%0 : Double(1)) { %1 : Double(1) = aten::clone(%0) - %2 : Double(1) = aten::add[other={2}, alpha={1}](%1) - %3 : Double(1) = aten::add[other={3}, alpha={1}](%2) - return (%3); + %2 : int = prim::Constant[value=2]() + %3 : int = prim::Constant[value=1]() + %4 : Double(1) = aten::add(%1, %2, %3) + %5 : int = prim::Constant[value=3]() + %6 : int = prim::Constant[value=1]() + %7 : Double(1) = aten::add(%4, %5, %6) + return (%7); } diff --git a/test/expect/TestJit.test_lstm_fusion_concat.expect b/test/expect/TestJit.test_lstm_fusion_concat.expect index 7f6b3f1c8b1b9c..7884a95c48c9a1 100644 --- a/test/expect/TestJit.test_lstm_fusion_concat.expect +++ b/test/expect/TestJit.test_lstm_fusion_concat.expect @@ -6,38 +6,44 @@ graph(%0 : Float(3, 10) %5 : Float(80) %6 : Float(80)) { %7 : Float(10!, 80!) 
= aten::t(%3) - %8 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=0](%5) - %9 : Float(3, 80) = aten::addmm[alpha={1}, beta={1}](%8, %0, %7) + %8 : int = prim::Constant[value=1]() + %9 : Float(3, 80) = aten::addmm(%5, %0, %7, %8, %8) %10 : Float(20!, 80!) = aten::t(%4) - %11 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=0](%6) - %12 : Float(3, 80) = aten::addmm[alpha={1}, beta={1}](%11, %1, %10) - %13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%9) - %17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%12) + %11 : Float(3, 80) = aten::addmm(%6, %1, %10, %8, %8) + %12 : int = prim::Constant[value=4]() + %13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk(%9, %12, %8) + %17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk(%11, %12, %8) %21 : Float(6, 20) = prim::FusionGroup_0[device=0](%2, %16, %20, %15, %19, %14, %18, %13, %17) return (%21); } -with prim::FusionGroup_0 = graph(%14 : Float(3, 20) - %24 : Float(3!, 20) - %25 : Float(3!, 20) +with prim::FusionGroup_0 = graph(%16 : Float(3, 20) + %26 : Float(3!, 20) %27 : Float(3!, 20) - %28 : Float(3!, 20) %30 : Float(3!, 20) %31 : Float(3!, 20) - %33 : Float(3!, 20) - %34 : Float(3!, 20)) { - %35 : Float(3, 20) = aten::add[alpha={1}](%33, %34) - %32 : Float(3, 20) = aten::add[alpha={1}](%30, %31) - %29 : Float(3, 20) = aten::add[alpha={1}](%27, %28) - %26 : Float(3, 20) = aten::add[alpha={1}](%24, %25) - %23 : Float(3, 20) = aten::sigmoid(%35) - %21 : Float(3, 20) = aten::sigmoid(%32) - %19 : Float(3, 20) = aten::tanh(%29) - %17 : Float(3, 20) = aten::sigmoid(%26) - %15 : Float(3, 20) = aten::mul(%21, %14) - %12 : Float(3, 20) = aten::mul(%23, %19) - %9 : Float(3, 20) = aten::add[alpha={1}](%15, %12) - %6 : Float(3, 20) = aten::tanh(%9) - %5 : Float(3, 20) = aten::mul(%17, %6) - %2 : Float(6, 20) = aten::cat[dim=0](%5, %9) - return (%2); + %34 : Float(3!, 20) + %35 : Float(3!, 20) + %38 : Float(3!, 20) + %39 : Float(3!, 20)) { + %40 : int = prim::Constant[value=1]() + %41 : Float(3, 20) = aten::add(%38, %39, %40) + %36 : int = prim::Constant[value=1]() + %37 : Float(3, 20) = aten::add(%34, %35, %36) + %32 : int = prim::Constant[value=1]() + %33 : Float(3, 20) = aten::add(%30, %31, %32) + %28 : int = prim::Constant[value=1]() + %29 : Float(3, 20) = aten::add(%26, %27, %28) + %25 : Float(3, 20) = aten::sigmoid(%41) + %23 : Float(3, 20) = aten::sigmoid(%37) + %21 : Float(3, 20) = aten::tanh(%33) + %19 : Float(3, 20) = aten::sigmoid(%29) + %17 : Float(3, 20) = aten::mul(%23, %16) + %14 : Float(3, 20) = aten::mul(%25, %21) + %10 : int = prim::Constant[value=1]() + %11 : Float(3, 20) = aten::add(%17, %14, %10) + %7 : Float(3, 20) = aten::tanh(%11) + %6 : Float(3, 20) = aten::mul(%19, %7) + %2 : int = prim::Constant[value=0]() + %3 : Float(6, 20) = aten::cat(%6, %11, %2) + return (%3); } diff --git a/test/expect/TestJit.test_lstm_fusion_cuda.expect b/test/expect/TestJit.test_lstm_fusion_cuda.expect index f2393996d11415..06be6cbb5d44a1 100644 --- a/test/expect/TestJit.test_lstm_fusion_cuda.expect +++ b/test/expect/TestJit.test_lstm_fusion_cuda.expect @@ -6,37 +6,42 @@ graph(%0 : Float(3, 10) %5 : Float(80) %6 : Float(80)) { %7 : Float(10!, 80!) 
= aten::t(%3) - %8 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=0](%5) - %9 : Float(3, 80) = aten::addmm[alpha={1}, beta={1}](%8, %0, %7) + %8 : int = prim::Constant[value=1]() + %9 : Float(3, 80) = aten::addmm(%5, %0, %7, %8, %8) %10 : Float(20!, 80!) = aten::t(%4) - %11 : Float(3!, 80) = aten::expand[size=[3, 80], implicit=0](%6) - %12 : Float(3, 80) = aten::addmm[alpha={1}, beta={1}](%11, %1, %10) - %13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%9) - %17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk[chunks=4, dim=1](%12) + %11 : Float(3, 80) = aten::addmm(%6, %1, %10, %8, %8) + %12 : int = prim::Constant[value=4]() + %13 : Float(3!, 20), %14 : Float(3!, 20), %15 : Float(3!, 20), %16 : Float(3!, 20) = aten::chunk(%9, %12, %8) + %17 : Float(3!, 20), %18 : Float(3!, 20), %19 : Float(3!, 20), %20 : Float(3!, 20) = aten::chunk(%11, %12, %8) %21 : Float(3, 20), %22 : Float(3, 20) = prim::FusionGroup_0[device=0](%2, %16, %20, %15, %19, %14, %18, %13, %17) return (%21, %22); } -with prim::FusionGroup_0 = graph(%12 : Float(3, 20) - %22 : Float(3!, 20) +with prim::FusionGroup_0 = graph(%13 : Float(3, 20) %23 : Float(3!, 20) - %25 : Float(3!, 20) - %26 : Float(3!, 20) + %24 : Float(3!, 20) + %27 : Float(3!, 20) %28 : Float(3!, 20) - %29 : Float(3!, 20) %31 : Float(3!, 20) - %32 : Float(3!, 20)) { - %33 : Float(3, 20) = aten::add[alpha={1}](%31, %32) - %30 : Float(3, 20) = aten::add[alpha={1}](%28, %29) - %27 : Float(3, 20) = aten::add[alpha={1}](%25, %26) - %24 : Float(3, 20) = aten::add[alpha={1}](%22, %23) - %21 : Float(3, 20) = aten::sigmoid(%33) - %19 : Float(3, 20) = aten::sigmoid(%30) - %17 : Float(3, 20) = aten::tanh(%27) - %15 : Float(3, 20) = aten::sigmoid(%24) - %13 : Float(3, 20) = aten::mul(%19, %12) - %10 : Float(3, 20) = aten::mul(%21, %17) - %7 : Float(3, 20) = aten::add[alpha={1}](%13, %10) - %4 : Float(3, 20) = aten::tanh(%7) - %2 : Float(3, 20) = aten::mul(%15, %4) - return (%2, %7); + %32 : Float(3!, 20) + %35 : Float(3!, 20) + %36 : Float(3!, 20)) { + %37 : int = prim::Constant[value=1]() + %38 : Float(3, 20) = aten::add(%35, %36, %37) + %33 : int = prim::Constant[value=1]() + %34 : Float(3, 20) = aten::add(%31, %32, %33) + %29 : int = prim::Constant[value=1]() + %30 : Float(3, 20) = aten::add(%27, %28, %29) + %25 : int = prim::Constant[value=1]() + %26 : Float(3, 20) = aten::add(%23, %24, %25) + %22 : Float(3, 20) = aten::sigmoid(%38) + %20 : Float(3, 20) = aten::sigmoid(%34) + %18 : Float(3, 20) = aten::tanh(%30) + %16 : Float(3, 20) = aten::sigmoid(%26) + %14 : Float(3, 20) = aten::mul(%20, %13) + %11 : Float(3, 20) = aten::mul(%22, %18) + %7 : int = prim::Constant[value=1]() + %8 : Float(3, 20) = aten::add(%14, %11, %7) + %4 : Float(3, 20) = aten::tanh(%8) + %2 : Float(3, 20) = aten::mul(%16, %4) + return (%2, %8); } diff --git a/test/expect/TestJit.test_nested_inplace.expect b/test/expect/TestJit.test_nested_inplace.expect index ff7e60b1c7d5ab..fd21055854faba 100644 --- a/test/expect/TestJit.test_nested_inplace.expect +++ b/test/expect/TestJit.test_nested_inplace.expect @@ -1,4 +1,6 @@ graph(%0 : Double(2, 2)) { - %1 : Double(2, 2) = aten::threshold[threshold={0}, value={0}](%0) - return (%1); + %1 : int = prim::Constant[value=0]() + %2 : int = prim::Constant[value=0]() + %3 : Double(2, 2) = aten::threshold(%0, %1, %2) + return (%3); } diff --git a/test/expect/TestJit.test_python_ir.expect b/test/expect/TestJit.test_python_ir.expect index 
1bb094dcd8bfd3..59ed07b6fdc9f0 100644 --- a/test/expect/TestJit.test_python_ir.expect +++ b/test/expect/TestJit.test_python_ir.expect @@ -1,9 +1,10 @@ graph(%0 : Dynamic %1 : Dynamic) { - %2 : Double(1) = aten::add[alpha={1}](%0, %1) - %3 : Double(1) = aten::mul(%0, %2) - %4 : Double(1) = aten::tanh(%3) - %5 : Double(1) = aten::sigmoid(%4) - %6 : Dynamic = prim::TensorTest[a= 1 1 1 1 [ CPUDoubleType{2,2} ]]() - return (%5); + %2 : int = prim::Constant[value=1]() + %3 : Double(1) = aten::add(%0, %1, %2) + %4 : Double(1) = aten::mul(%0, %3) + %5 : Double(1) = aten::tanh(%4) + %6 : Double(1) = aten::sigmoid(%5) + %7 : Dynamic = prim::TensorTest[a= 1 1 1 1 [ CPUDoubleType{2,2} ]]() + return (%6); } diff --git a/test/expect/TestJit.test_repeated_input.expect b/test/expect/TestJit.test_repeated_input.expect index 57e57066ef503b..ac67a6c14fc972 100644 --- a/test/expect/TestJit.test_repeated_input.expect +++ b/test/expect/TestJit.test_repeated_input.expect @@ -1,5 +1,6 @@ graph(%0 : Double(2, 2) %1 : Double(2, 2)) { - %2 : Double(2, 2) = aten::add[alpha={1}](%0, %1) - return (%2); + %2 : int = prim::Constant[value=1]() + %3 : Double(2, 2) = aten::add(%0, %1, %2) + return (%3); } diff --git a/test/expect/TestJit.test_repeated_output.expect b/test/expect/TestJit.test_repeated_output.expect index b3baff631ebe0d..64a937aef7fb6a 100644 --- a/test/expect/TestJit.test_repeated_output.expect +++ b/test/expect/TestJit.test_repeated_output.expect @@ -1,5 +1,6 @@ graph(%0 : Double(2, 2) %1 : Double(2, 2)) { - %2 : Double(2, 2) = aten::add[alpha={1}](%0, %1) - return (%2, %2); + %2 : int = prim::Constant[value=1]() + %3 : Double(2, 2) = aten::add(%0, %1, %2) + return (%3, %3); } diff --git a/test/expect/TestJit.test_scopes.expect b/test/expect/TestJit.test_scopes.expect index 05578370f5cfbf..3cbbb0b966afe1 100644 --- a/test/expect/TestJit.test_scopes.expect +++ b/test/expect/TestJit.test_scopes.expect @@ -1,8 +1,9 @@ graph(%0 : Double(1) %1 : Double(1)) { - %2 : Double(1) = aten::add[alpha={1}](%0, %1) - %3 : Double(1) = aten::mul(%0, %2), scope: Foo - %4 : Double(1) = aten::tanh(%3), scope: Foo/Bar - %5 : Double(1) = aten::sigmoid(%4), scope: Foo - return (%5); + %2 : int = prim::Constant[value=1]() + %3 : Double(1) = aten::add(%0, %1, %2) + %4 : Double(1) = aten::mul(%0, %3), scope: Foo + %5 : Double(1) = aten::tanh(%4), scope: Foo/Bar + %6 : Double(1) = aten::sigmoid(%5), scope: Foo + return (%6); } diff --git a/test/expect/TestJit.test_shape_analysis_broadcast.expect b/test/expect/TestJit.test_shape_analysis_broadcast.expect index bbe5b741649d0a..e238c3fe1adc13 100644 --- a/test/expect/TestJit.test_shape_analysis_broadcast.expect +++ b/test/expect/TestJit.test_shape_analysis_broadcast.expect @@ -1,7 +1,12 @@ graph(%a : Double(3, 1, 5) %b : Double(4, 1, 8, 5)) { - %2 : Double(4!, 3!, 8!, 5) = aten::expand[size=[4, 3, 8, 5], implicit=0](%a) - %3 : Double(4!, 3!, 8, 5) = aten::expand[size=[4, 3, 8, 5], implicit=0](%b) - %4 : Double(4, 3, 8, 5) = aten::add[alpha={1}](%2, %3) - return (%4); + %2 : int = prim::Constant[value=1]() + %3 : int[] = prim::Constant[value=[4, 3, 8, 5]]() + %4 : int = prim::Constant[value=0]() + %5 : Double(4!, 3!, 8!, 5) = aten::expand(%a, %3, %4) + %6 : int[] = prim::Constant[value=[4, 3, 8, 5]]() + %7 : int = prim::Constant[value=0]() + %8 : Double(4!, 3!, 8, 5) = aten::expand(%b, %6, %7) + %9 : Double(4, 3, 8, 5) = aten::add(%5, %8, %2) + return (%9); } diff --git a/test/expect/TestJit.test_shared_param.expect b/test/expect/TestJit.test_shared_param.expect index 
ec758dfb7d87af..1b0a2c25be34bb 100644 --- a/test/expect/TestJit.test_shared_param.expect +++ b/test/expect/TestJit.test_shared_param.expect @@ -1,6 +1,7 @@ graph(%0 : Double(2, 2) %1 : Double(2, 2)) { %2 : Double(2, 2) = aten::mul(%0, %1), scope: MyModule - %3 : Double(2, 2) = aten::add[alpha={1}](%2, %1), scope: MyModule - return (%3); + %3 : int = prim::Constant[value=1](), scope: MyModule + %4 : Double(2, 2) = aten::add(%2, %1, %3), scope: MyModule + return (%4); } diff --git a/test/expect/TestJit.test_simple.expect b/test/expect/TestJit.test_simple.expect index 1db84de676b3e5..bfa7408b6be17c 100644 --- a/test/expect/TestJit.test_simple.expect +++ b/test/expect/TestJit.test_simple.expect @@ -1,8 +1,9 @@ graph(%0 : Double(1) %1 : Double(1)) { - %2 : Double(1) = aten::add[alpha={1}](%0, %1) - %3 : Double(1) = aten::mul(%0, %2) - %4 : Double(1) = aten::tanh(%3) - %5 : Double(1) = aten::sigmoid(%4) - return (%5); + %2 : int = prim::Constant[value=1]() + %3 : Double(1) = aten::add(%0, %1, %2) + %4 : Double(1) = aten::mul(%0, %3) + %5 : Double(1) = aten::tanh(%4) + %6 : Double(1) = aten::sigmoid(%5) + return (%6); } diff --git a/test/expect/TestJit.test_trace_size.expect b/test/expect/TestJit.test_trace_size.expect index 567a0fc5a5ecb3..8068691735c3dd 100644 --- a/test/expect/TestJit.test_trace_size.expect +++ b/test/expect/TestJit.test_trace_size.expect @@ -2,14 +2,15 @@ graph(%0 : Double(5, 2, 4)) { %1 : int = prim::Constant[value=1]() %2 : int = aten::size(%0, %1) %3 : Long() = prim::NumToTensor(%2) - %4 : Long() = aten::mul[other={2}](%3) - %5 : int = prim::TensorToNum(%4) - %6 : int = prim::Constant[value=0]() - %7 : int = aten::size(%0, %6) - %8 : Long() = prim::NumToTensor(%7) - %9 : int = prim::TensorToNum(%8) - %10 : int = prim::Constant[value=2]() - %11 : int[] = prim::ListConstruct(%5, %9, %10) - %12 : Double(4, 5, 2) = aten::view(%0, %11) - return (%12); + %4 : int = prim::Constant[value=2]() + %5 : Long() = aten::mul(%3, %4) + %6 : int = prim::TensorToNum(%5) + %7 : int = prim::Constant[value=0]() + %8 : int = aten::size(%0, %7) + %9 : Long() = prim::NumToTensor(%8) + %10 : int = prim::TensorToNum(%9) + %11 : int = prim::Constant[value=2]() + %12 : int[] = prim::ListConstruct(%6, %10, %11) + %13 : Double(4, 5, 2) = aten::view(%0, %12) + return (%13); } diff --git a/test/expect/TestJit.test_trace_size_with_grad.expect b/test/expect/TestJit.test_trace_size_with_grad.expect index 567a0fc5a5ecb3..8068691735c3dd 100644 --- a/test/expect/TestJit.test_trace_size_with_grad.expect +++ b/test/expect/TestJit.test_trace_size_with_grad.expect @@ -2,14 +2,15 @@ graph(%0 : Double(5, 2, 4)) { %1 : int = prim::Constant[value=1]() %2 : int = aten::size(%0, %1) %3 : Long() = prim::NumToTensor(%2) - %4 : Long() = aten::mul[other={2}](%3) - %5 : int = prim::TensorToNum(%4) - %6 : int = prim::Constant[value=0]() - %7 : int = aten::size(%0, %6) - %8 : Long() = prim::NumToTensor(%7) - %9 : int = prim::TensorToNum(%8) - %10 : int = prim::Constant[value=2]() - %11 : int[] = prim::ListConstruct(%5, %9, %10) - %12 : Double(4, 5, 2) = aten::view(%0, %11) - return (%12); + %4 : int = prim::Constant[value=2]() + %5 : Long() = aten::mul(%3, %4) + %6 : int = prim::TensorToNum(%5) + %7 : int = prim::Constant[value=0]() + %8 : int = aten::size(%0, %7) + %9 : Long() = prim::NumToTensor(%8) + %10 : int = prim::TensorToNum(%9) + %11 : int = prim::Constant[value=2]() + %12 : int[] = prim::ListConstruct(%6, %10, %11) + %13 : Double(4, 5, 2) = aten::view(%0, %12) + return (%13); } diff --git 
a/test/expect/TestScript.test_call_python_fn_from_script_fn.expect b/test/expect/TestScript.test_call_python_fn_from_script_fn.expect index db478d2e22f9cb..297d0918700a02 100644 --- a/test/expect/TestScript.test_call_python_fn_from_script_fn.expect +++ b/test/expect/TestScript.test_call_python_fn_from_script_fn.expect @@ -1,5 +1,7 @@ graph(%x : Dynamic) { %1 : Dynamic = ^python_fn()(%x) - %5 : Dynamic = aten::add[other={1}, alpha={1}](%1) + %2 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=1]() + %5 : Dynamic = aten::add(%1, %2, %4) return (%5); } diff --git a/test/expect/TestScript.test_call_python_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_python_fn_from_tracing_fn.expect index 4eb19bbc83c7e7..ac76985db76fb3 100644 --- a/test/expect/TestScript.test_call_python_fn_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_python_fn_from_tracing_fn.expect @@ -1,5 +1,7 @@ graph(%0 : Double(3, 4)) { %1 : Double(3, 4) = aten::neg(%0) - %2 : Double(3, 4) = aten::add[other={1}, alpha={1}](%1) - return (%2); + %2 : int = prim::Constant[value=1]() + %3 : int = prim::Constant[value=1]() + %4 : Double(3, 4) = aten::add(%1, %2, %3) + return (%4); } diff --git a/test/expect/TestScript.test_call_python_mod_from_script_fn.expect b/test/expect/TestScript.test_call_python_mod_from_script_fn.expect index ec5fd842f3b864..260bbaba6462f7 100644 --- a/test/expect/TestScript.test_call_python_mod_from_script_fn.expect +++ b/test/expect/TestScript.test_call_python_mod_from_script_fn.expect @@ -1,5 +1,7 @@ graph(%x : Dynamic) { %1 : Dynamic = ^()(%x) - %5 : Dynamic = aten::add[other={1}, alpha={1}](%1) + %2 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=1]() + %5 : Dynamic = aten::add(%1, %2, %4) return (%5); } diff --git a/test/expect/TestScript.test_call_python_mod_from_traced_module.expect b/test/expect/TestScript.test_call_python_mod_from_traced_module.expect index d39acaf5257d3d..863cbdf2d5a4fa 100644 --- a/test/expect/TestScript.test_call_python_mod_from_traced_module.expect +++ b/test/expect/TestScript.test_call_python_mod_from_traced_module.expect @@ -3,6 +3,8 @@ graph(%0 : Double(3, 4) %2 : Double(5, 7)) { %4 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule %6 : Double(3, 7) = aten::mm(%4, %2), scope: TracedModule/PythonModule[mod] - %7 : Double(3, 7) = aten::add[other={1}, alpha={1}](%6), scope: TracedModule - return (%7); + %7 : int = prim::Constant[value=1](), scope: TracedModule + %8 : int = prim::Constant[value=1](), scope: TracedModule + %9 : Double(3, 7) = aten::add(%6, %7, %8), scope: TracedModule + return (%9); } diff --git a/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect index ea847d630c8ba5..a23a6bd2730368 100644 --- a/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect @@ -1,6 +1,8 @@ graph(%0 : Double(3, 4)) { %1 : Double(4, 3) = prim::Constant[value=](), scope: PythonMod %3 : Double(3, 3) = aten::mm(%0, %1), scope: PythonMod - %4 : Double(3, 3) = aten::add[other={1}, alpha={1}](%3) - return (%4); + %4 : int = prim::Constant[value=1]() + %5 : int = prim::Constant[value=1]() + %6 : Double(3, 3) = aten::add(%3, %4, %5) + return (%6); } diff --git a/test/expect/TestScript.test_call_script_fn_from_script_fn.expect b/test/expect/TestScript.test_call_script_fn_from_script_fn.expect index e36a68926dccce..8c23ad8f353bef 100644 --- 
a/test/expect/TestScript.test_call_script_fn_from_script_fn.expect +++ b/test/expect/TestScript.test_call_script_fn_from_script_fn.expect @@ -1,5 +1,7 @@ graph(%x : Dynamic) { %1 : Dynamic = aten::neg(%x) - %5 : Dynamic = aten::add[other={1}, alpha={1}](%1) + %2 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=1]() + %5 : Dynamic = aten::add(%1, %2, %4) return (%5); } diff --git a/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect index dc8b4945df4773..d12a6a40520bdf 100644 --- a/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect @@ -1,5 +1,7 @@ graph(%0 : Double(3, 4)) { %2 : Double(3, 4) = aten::neg(%0), scope: ScriptModule - %3 : Double(3, 4) = aten::add[other={1}, alpha={1}](%2) - return (%3); + %3 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=1]() + %5 : Double(3, 4) = aten::add(%2, %3, %4) + return (%5); } diff --git a/test/expect/TestScript.test_call_script_mod_from_script_fn.expect b/test/expect/TestScript.test_call_script_mod_from_script_fn.expect index e24d034b26e3da..98cf4ade03b461 100644 --- a/test/expect/TestScript.test_call_script_mod_from_script_fn.expect +++ b/test/expect/TestScript.test_call_script_mod_from_script_fn.expect @@ -7,6 +7,8 @@ graph(%x : Dynamic) { %6 : int[] = prim::ListConstruct(%1, %2) %7 : Dynamic = aten::zeros(%6, %3, %4, %5) %8 : Dynamic = aten::mm(%x, %7) - %12 : Dynamic = aten::add[other={1}, alpha={1}](%8) + %9 : int = prim::Constant[value=1]() + %11 : int = prim::Constant[value=1]() + %12 : Dynamic = aten::add(%8, %9, %11) return (%12); } diff --git a/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect index fc7039bd971f23..4c098d3b9f16a4 100644 --- a/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect @@ -1,6 +1,8 @@ graph(%0 : Double(3, 4)) { %1 : Double(4, 3) = prim::Constant[value=](), scope: ScriptMod %4 : Double(3, 3) = aten::mm(%0, %1), scope: ScriptMod - %5 : Double(3, 3) = aten::add[other={1}, alpha={1}](%4) - return (%5); + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() + %7 : Double(3, 3) = aten::add(%4, %5, %6) + return (%7); } diff --git a/test/expect/TestScript.test_call_script_module_from_traced_module.expect b/test/expect/TestScript.test_call_script_module_from_traced_module.expect index 21b14a2a62f8cf..1a452935eb5fc8 100644 --- a/test/expect/TestScript.test_call_script_module_from_traced_module.expect +++ b/test/expect/TestScript.test_call_script_module_from_traced_module.expect @@ -3,6 +3,8 @@ graph(%0 : Double(3, 4) %2 : Double(5, 7)) { %4 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule %7 : Double(3, 7) = aten::mm(%4, %2), scope: TracedModule/ScriptMod[mod] - %8 : Double(3, 7) = aten::add[other={1}, alpha={1}](%7), scope: TracedModule - return (%8); + %8 : int = prim::Constant[value=1](), scope: TracedModule + %9 : int = prim::Constant[value=1](), scope: TracedModule + %10 : Double(3, 7) = aten::add(%7, %8, %9), scope: TracedModule + return (%10); } diff --git a/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect b/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect index 83ce62e68e086d..2cae32e2be01a7 100644 --- a/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect +++ 
b/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect @@ -1,5 +1,7 @@ graph(%x : Dynamic) { %1 : Double(3, 4) = aten::neg(%x) - %5 : Dynamic = aten::add[other={1}, alpha={1}](%1) + %2 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=1]() + %5 : Dynamic = aten::add(%1, %2, %4) return (%5); } diff --git a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect index ed737f4b6580b4..27eb2b6e7814e0 100644 --- a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect @@ -1,5 +1,7 @@ graph(%0 : Double(3, 4)) { %2 : Double(3, 4) = aten::neg(%0), scope: traced_fn1 - %3 : Double(3, 4) = aten::add[other={1}, alpha={1}](%2) - return (%3); + %3 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=1]() + %5 : Double(3, 4) = aten::add(%2, %3, %4) + return (%5); } diff --git a/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect b/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect index 9a99fbe83f1d4d..315ec3464487be 100644 --- a/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect +++ b/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect @@ -1,6 +1,8 @@ graph(%x : Dynamic) { %1 : Double(4, 3) = prim::Constant[value=]() %2 : Double(3, 3) = aten::mm(%x, %1) - %6 : Dynamic = aten::add[other={1}, alpha={1}](%2) + %3 : int = prim::Constant[value=1]() + %5 : int = prim::Constant[value=1]() + %6 : Dynamic = aten::add(%2, %3, %5) return (%6); } diff --git a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect index 3fac45fc2dfdab..f5c6f1bb2c18d8 100644 --- a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect @@ -1,6 +1,8 @@ graph(%0 : Double(3, 4)) { %1 : Double(4, 3) = prim::Constant[value=](), scope: TracedModule[TracedModule] %4 : Double(3, 3) = aten::mm(%0, %1), scope: TracedModule[TracedModule] - %5 : Double(3, 3) = aten::add[other={1}, alpha={1}](%4) - return (%5); + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() + %7 : Double(3, 3) = aten::add(%4, %5, %6) + return (%7); } diff --git a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect b/test/expect/TestScript.test_call_traced_module_from_traced_module.expect index 471f9f1c2ec3fe..f66573f6da2f25 100644 --- a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect +++ b/test/expect/TestScript.test_call_traced_module_from_traced_module.expect @@ -3,6 +3,8 @@ graph(%0 : Double(3, 4) %2 : Double(5, 7)) { %4 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule %7 : Double(3, 7) = aten::mm(%4, %2), scope: TracedModule/TracedModule[TracedModule1][mod] - %8 : Double(3, 7) = aten::add[other={1}, alpha={1}](%7), scope: TracedModule - return (%8); + %8 : int = prim::Constant[value=1](), scope: TracedModule + %9 : int = prim::Constant[value=1](), scope: TracedModule + %10 : Double(3, 7) = aten::add(%7, %8, %9), scope: TracedModule + return (%10); } diff --git a/test/expect/TestScript.test_cat_lifts.expect b/test/expect/TestScript.test_cat_lifts.expect index 5bcef43f7c7a3d..ea2fa3737c0556 100644 --- a/test/expect/TestScript.test_cat_lifts.expect +++ b/test/expect/TestScript.test_cat_lifts.expect @@ -1,12 +1,15 @@ graph(%x : Dynamic) { - %1 : Dynamic = aten::cat[dim=1](%x, 
%x) - return (%1); + %1 : int = prim::Constant[value=1]() + %2 : Dynamic = aten::cat(%x, %x, %1) + return (%2); } graph(%x : Dynamic) { - %1 : Dynamic = aten::cat[dim=1]() - return (%1); + %1 : int = prim::Constant[value=1]() + %2 : Dynamic = aten::cat(%1) + return (%2); } graph(%x : Dynamic) { - %1 : Dynamic = aten::cat[dim=1](%x) - return (%1); + %1 : int = prim::Constant[value=1]() + %2 : Dynamic = aten::cat(%x, %1) + return (%2); } diff --git a/test/expect/TestScript.test_index_put_trace_with_view.expect b/test/expect/TestScript.test_index_put_trace_with_view.expect index 24ff0fe32c451f..591e499da96671 100644 --- a/test/expect/TestScript.test_index_put_trace_with_view.expect +++ b/test/expect/TestScript.test_index_put_trace_with_view.expect @@ -1,8 +1,11 @@ graph(%0 : Double(100) %1 : Long(4) %2 : Double(1, 1, 1, 4)) { - %3 : Double(4) = aten::view[size=[4]](%2) - %4 : Long(4) = aten::_cast_Long[non_blocking=0](%1) - %11 : Double(100) = aten::index_put(%0, %4, %3) - return (%11); + %3 : int = prim::Constant[value=4]() + %4 : int[] = prim::ListConstruct(%3) + %5 : Double(4) = aten::view(%2, %4) + %6 : int = prim::Constant[value=0]() + %7 : Long(4) = aten::_cast_Long(%1, %6) + %19 : Double(100) = aten::index_put(%0, %7, %5) + return (%19); } diff --git a/test/expect/TestScript.test_index_put_trace_without_view.expect b/test/expect/TestScript.test_index_put_trace_without_view.expect index f483213b481461..42f8e49142942e 100644 --- a/test/expect/TestScript.test_index_put_trace_without_view.expect +++ b/test/expect/TestScript.test_index_put_trace_without_view.expect @@ -1,7 +1,8 @@ graph(%0 : Double(100) %1 : Long(4) %2 : Double(4)) { - %3 : Long(4) = aten::_cast_Long[non_blocking=0](%1) - %10 : Double(100) = aten::index_put(%0, %3, %2) - return (%10); + %3 : int = prim::Constant[value=0]() + %4 : Long(4) = aten::_cast_Long(%1, %3) + %16 : Double(100) = aten::index_put(%0, %4, %2) + return (%16); } diff --git a/test/expect/TestScript.test_index_select_shape_prop.expect b/test/expect/TestScript.test_index_select_shape_prop.expect index 32a9d7744e52cc..f24249a21f9d20 100644 --- a/test/expect/TestScript.test_index_select_shape_prop.expect +++ b/test/expect/TestScript.test_index_select_shape_prop.expect @@ -1,5 +1,6 @@ graph(%x : Double(2, 2) %y : Long(4)) { - %2 : Double(2, 4) = aten::index_select[dim=1](%x, %y) - return (%2); + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::index_select(%x, %2, %y) + return (%3); } diff --git a/test/expect/TestScript.test_loop_unroll_unused_counter.expect b/test/expect/TestScript.test_loop_unroll_unused_counter.expect index a4b5983c1e5d9c..be1a5efeecf449 100644 --- a/test/expect/TestScript.test_loop_unroll_unused_counter.expect +++ b/test/expect/TestScript.test_loop_unroll_unused_counter.expect @@ -9,22 +9,40 @@ graph(%x : Dynamic) { %8 : int = aten::sub(%2, %7) %y.3 : Dynamic = prim::Loop(%5, %3, %y.1) block0(%i.1 : int, %11 : Dynamic) { - %y.12 : Dynamic = aten::add[other={1}, alpha={1}](%11) - %y.5 : Dynamic = aten::add[other={1}, alpha={1}](%y.12) - %y.6 : Dynamic = aten::add[other={1}, alpha={1}](%y.5) - %y.7 : Dynamic = aten::add[other={1}, alpha={1}](%y.6) - %y.8 : Dynamic = aten::add[other={1}, alpha={1}](%y.7) - %y.9 : Dynamic = aten::add[other={1}, alpha={1}](%y.8) - %y.10 : Dynamic = aten::add[other={1}, alpha={1}](%y.9) - %y.11 : Dynamic = aten::add[other={1}, alpha={1}](%y.10) - %20 : int = prim::Constant[value=1]() - -> (%20, %y.11) + %12 : int = prim::Constant[value=1]() + %13 : int = prim::Constant[value=1]() + %y.12 : Dynamic = 
aten::add(%11, %12, %13) + %15 : int = prim::Constant[value=1]() + %16 : int = prim::Constant[value=1]() + %y.5 : Dynamic = aten::add(%y.12, %15, %16) + %18 : int = prim::Constant[value=1]() + %19 : int = prim::Constant[value=1]() + %y.6 : Dynamic = aten::add(%y.5, %18, %19) + %21 : int = prim::Constant[value=1]() + %22 : int = prim::Constant[value=1]() + %y.7 : Dynamic = aten::add(%y.6, %21, %22) + %24 : int = prim::Constant[value=1]() + %25 : int = prim::Constant[value=1]() + %y.8 : Dynamic = aten::add(%y.7, %24, %25) + %27 : int = prim::Constant[value=1]() + %28 : int = prim::Constant[value=1]() + %y.9 : Dynamic = aten::add(%y.8, %27, %28) + %30 : int = prim::Constant[value=1]() + %31 : int = prim::Constant[value=1]() + %y.10 : Dynamic = aten::add(%y.9, %30, %31) + %33 : int = prim::Constant[value=1]() + %34 : int = prim::Constant[value=1]() + %y.11 : Dynamic = aten::add(%y.10, %33, %34) + %36 : int = prim::Constant[value=1]() + -> (%36, %y.11) } %y : Dynamic = prim::Loop(%8, %3, %y.3) - block0(%i : int, %23 : Dynamic) { - %y.4 : Dynamic = aten::add[other={1}, alpha={1}](%23) - %25 : int = prim::Constant[value=1]() - -> (%25, %y.4) + block0(%i : int, %39 : Dynamic) { + %40 : int = prim::Constant[value=1]() + %41 : int = prim::Constant[value=1]() + %y.4 : Dynamic = aten::add(%39, %40, %41) + %43 : int = prim::Constant[value=1]() + -> (%43, %y.4) } return (%y); } diff --git a/test/expect/TestScript.test_loop_unrolling.expect b/test/expect/TestScript.test_loop_unrolling.expect index 0c77a4ec47e6ec..fc0ca446112036 100644 --- a/test/expect/TestScript.test_loop_unrolling.expect +++ b/test/expect/TestScript.test_loop_unrolling.expect @@ -10,35 +10,35 @@ graph(%x : Dynamic) { %9 : int = aten::sub(%2, %8) %10 : Dynamic, %y.3 : Dynamic = prim::Loop(%6, %3, %4, %y.1) block0(%i.1 : int, %13 : Dynamic, %14 : Dynamic) { - %15 : Number = prim::Constant[value=1]() + %15 : int = prim::Constant[value=1]() %y.12 : Dynamic = aten::add(%14, %13, %15) %17 : int = prim::Constant[value=1]() %18 : int = aten::add(%13, %17) - %19 : Number = prim::Constant[value=1]() + %19 : int = prim::Constant[value=1]() %y.5 : Dynamic = aten::add(%y.12, %18, %19) %21 : int = prim::Constant[value=1]() %22 : int = aten::add(%18, %21) - %23 : Number = prim::Constant[value=1]() + %23 : int = prim::Constant[value=1]() %y.6 : Dynamic = aten::add(%y.5, %22, %23) %25 : int = prim::Constant[value=1]() %26 : int = aten::add(%22, %25) - %27 : Number = prim::Constant[value=1]() + %27 : int = prim::Constant[value=1]() %y.7 : Dynamic = aten::add(%y.6, %26, %27) %29 : int = prim::Constant[value=1]() %30 : int = aten::add(%26, %29) - %31 : Number = prim::Constant[value=1]() + %31 : int = prim::Constant[value=1]() %y.8 : Dynamic = aten::add(%y.7, %30, %31) %33 : int = prim::Constant[value=1]() %34 : int = aten::add(%30, %33) - %35 : Number = prim::Constant[value=1]() + %35 : int = prim::Constant[value=1]() %y.9 : Dynamic = aten::add(%y.8, %34, %35) %37 : int = prim::Constant[value=1]() %38 : int = aten::add(%34, %37) - %39 : Number = prim::Constant[value=1]() + %39 : int = prim::Constant[value=1]() %y.10 : Dynamic = aten::add(%y.9, %38, %39) %41 : int = prim::Constant[value=1]() %42 : int = aten::add(%38, %41) - %43 : Number = prim::Constant[value=1]() + %43 : int = prim::Constant[value=1]() %y.11 : Dynamic = aten::add(%y.10, %42, %43) %45 : int = prim::Constant[value=1]() %46 : int = prim::Constant[value=1]() @@ -47,7 +47,7 @@ graph(%x : Dynamic) { } %48 : Dynamic, %y : Dynamic = prim::Loop(%9, %3, %10, %y.3) block0(%i : int, %51 : Dynamic, 
%52 : Dynamic) { - %53 : Number = prim::Constant[value=1]() + %53 : int = prim::Constant[value=1]() %y.4 : Dynamic = aten::add(%52, %51, %53) %55 : int = prim::Constant[value=1]() %56 : int = prim::Constant[value=1]() diff --git a/test/expect/TestScript.test_loop_unrolling_const-add_const.expect b/test/expect/TestScript.test_loop_unrolling_const-add_const.expect index 8f810b0a6339bf..6b7d615a1b7800 100644 --- a/test/expect/TestScript.test_loop_unrolling_const-add_const.expect +++ b/test/expect/TestScript.test_loop_unrolling_const-add_const.expect @@ -1,14 +1,34 @@ graph() { %y.1 : Dynamic = ^FIXME_zerol()() - %y.11 : Dynamic = aten::add[other={1}, alpha={1}](%y.1) - %y.2 : Dynamic = aten::add[other={1}, alpha={1}](%y.11) - %y.3 : Dynamic = aten::add[other={1}, alpha={1}](%y.2) - %y.4 : Dynamic = aten::add[other={1}, alpha={1}](%y.3) - %y.5 : Dynamic = aten::add[other={1}, alpha={1}](%y.4) - %y.6 : Dynamic = aten::add[other={1}, alpha={1}](%y.5) - %y.7 : Dynamic = aten::add[other={1}, alpha={1}](%y.6) - %y.8 : Dynamic = aten::add[other={1}, alpha={1}](%y.7) - %y.9 : Dynamic = aten::add[other={1}, alpha={1}](%y.8) - %y.10 : Dynamic = aten::add[other={1}, alpha={1}](%y.9) + %1 : int = prim::Constant[value=1]() + %2 : int = prim::Constant[value=1]() + %y.11 : Dynamic = aten::add(%y.1, %1, %2) + %4 : int = prim::Constant[value=1]() + %5 : int = prim::Constant[value=1]() + %y.2 : Dynamic = aten::add(%y.11, %4, %5) + %7 : int = prim::Constant[value=1]() + %8 : int = prim::Constant[value=1]() + %y.3 : Dynamic = aten::add(%y.2, %7, %8) + %10 : int = prim::Constant[value=1]() + %11 : int = prim::Constant[value=1]() + %y.4 : Dynamic = aten::add(%y.3, %10, %11) + %13 : int = prim::Constant[value=1]() + %14 : int = prim::Constant[value=1]() + %y.5 : Dynamic = aten::add(%y.4, %13, %14) + %16 : int = prim::Constant[value=1]() + %17 : int = prim::Constant[value=1]() + %y.6 : Dynamic = aten::add(%y.5, %16, %17) + %19 : int = prim::Constant[value=1]() + %20 : int = prim::Constant[value=1]() + %y.7 : Dynamic = aten::add(%y.6, %19, %20) + %22 : int = prim::Constant[value=1]() + %23 : int = prim::Constant[value=1]() + %y.8 : Dynamic = aten::add(%y.7, %22, %23) + %25 : int = prim::Constant[value=1]() + %26 : int = prim::Constant[value=1]() + %y.9 : Dynamic = aten::add(%y.8, %25, %26) + %28 : int = prim::Constant[value=1]() + %29 : int = prim::Constant[value=1]() + %y.10 : Dynamic = aten::add(%y.9, %28, %29) return (%y.10); } diff --git a/test/expect/TestScript.test_loop_unrolling_const-add_iter.expect b/test/expect/TestScript.test_loop_unrolling_const-add_iter.expect index 2618493dc8ecb4..ba142cc8092cd3 100644 --- a/test/expect/TestScript.test_loop_unrolling_const-add_iter.expect +++ b/test/expect/TestScript.test_loop_unrolling_const-add_iter.expect @@ -1,43 +1,43 @@ graph() { %y.1 : Dynamic = ^FIXME_zerol()() %1 : int = prim::Constant[value=0]() - %2 : Number = prim::Constant[value=1]() + %2 : int = prim::Constant[value=1]() %y.11 : Dynamic = aten::add(%y.1, %1, %2) %4 : int = prim::Constant[value=1]() %5 : int = aten::add(%1, %4) - %6 : Number = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() %y.2 : Dynamic = aten::add(%y.11, %5, %6) %8 : int = prim::Constant[value=1]() %9 : int = aten::add(%5, %8) - %10 : Number = prim::Constant[value=1]() + %10 : int = prim::Constant[value=1]() %y.3 : Dynamic = aten::add(%y.2, %9, %10) %12 : int = prim::Constant[value=1]() %13 : int = aten::add(%9, %12) - %14 : Number = prim::Constant[value=1]() + %14 : int = prim::Constant[value=1]() %y.4 : Dynamic = 
aten::add(%y.3, %13, %14) %16 : int = prim::Constant[value=1]() %17 : int = aten::add(%13, %16) - %18 : Number = prim::Constant[value=1]() + %18 : int = prim::Constant[value=1]() %y.5 : Dynamic = aten::add(%y.4, %17, %18) %20 : int = prim::Constant[value=1]() %21 : int = aten::add(%17, %20) - %22 : Number = prim::Constant[value=1]() + %22 : int = prim::Constant[value=1]() %y.6 : Dynamic = aten::add(%y.5, %21, %22) %24 : int = prim::Constant[value=1]() %25 : int = aten::add(%21, %24) - %26 : Number = prim::Constant[value=1]() + %26 : int = prim::Constant[value=1]() %y.7 : Dynamic = aten::add(%y.6, %25, %26) %28 : int = prim::Constant[value=1]() %29 : int = aten::add(%25, %28) - %30 : Number = prim::Constant[value=1]() + %30 : int = prim::Constant[value=1]() %y.8 : Dynamic = aten::add(%y.7, %29, %30) %32 : int = prim::Constant[value=1]() %33 : int = aten::add(%29, %32) - %34 : Number = prim::Constant[value=1]() + %34 : int = prim::Constant[value=1]() %y.9 : Dynamic = aten::add(%y.8, %33, %34) %36 : int = prim::Constant[value=1]() %37 : int = aten::add(%33, %36) - %38 : Number = prim::Constant[value=1]() + %38 : int = prim::Constant[value=1]() %y.10 : Dynamic = aten::add(%y.9, %37, %38) return (%y.10); } diff --git a/test/expect/TestScript.test_loop_unrolling_nested.expect b/test/expect/TestScript.test_loop_unrolling_nested.expect index 3b8832d03071a2..cac82d3f3ba210 100644 --- a/test/expect/TestScript.test_loop_unrolling_nested.expect +++ b/test/expect/TestScript.test_loop_unrolling_nested.expect @@ -14,35 +14,35 @@ graph(%x : Dynamic) { %14 : int = aten::sub(%7, %13) %15 : Dynamic, %y.4 : Dynamic = prim::Loop(%11, %8, %9, %6) block0(%j.1 : int, %18 : Dynamic, %19 : Dynamic) { - %20 : Number = prim::Constant[value=1]() + %20 : int = prim::Constant[value=1]() %y.13 : Dynamic = aten::add(%19, %18, %20) %22 : int = prim::Constant[value=1]() %23 : int = aten::add(%18, %22) - %24 : Number = prim::Constant[value=1]() + %24 : int = prim::Constant[value=1]() %y.6 : Dynamic = aten::add(%y.13, %23, %24) %26 : int = prim::Constant[value=1]() %27 : int = aten::add(%23, %26) - %28 : Number = prim::Constant[value=1]() + %28 : int = prim::Constant[value=1]() %y.7 : Dynamic = aten::add(%y.6, %27, %28) %30 : int = prim::Constant[value=1]() %31 : int = aten::add(%27, %30) - %32 : Number = prim::Constant[value=1]() + %32 : int = prim::Constant[value=1]() %y.8 : Dynamic = aten::add(%y.7, %31, %32) %34 : int = prim::Constant[value=1]() %35 : int = aten::add(%31, %34) - %36 : Number = prim::Constant[value=1]() + %36 : int = prim::Constant[value=1]() %y.9 : Dynamic = aten::add(%y.8, %35, %36) %38 : int = prim::Constant[value=1]() %39 : int = aten::add(%35, %38) - %40 : Number = prim::Constant[value=1]() + %40 : int = prim::Constant[value=1]() %y.10 : Dynamic = aten::add(%y.9, %39, %40) %42 : int = prim::Constant[value=1]() %43 : int = aten::add(%39, %42) - %44 : Number = prim::Constant[value=1]() + %44 : int = prim::Constant[value=1]() %y.11 : Dynamic = aten::add(%y.10, %43, %44) %46 : int = prim::Constant[value=1]() %47 : int = aten::add(%43, %46) - %48 : Number = prim::Constant[value=1]() + %48 : int = prim::Constant[value=1]() %y.12 : Dynamic = aten::add(%y.11, %47, %48) %50 : int = prim::Constant[value=1]() %51 : int = prim::Constant[value=1]() @@ -51,7 +51,7 @@ graph(%x : Dynamic) { } %53 : Dynamic, %y.3 : Dynamic = prim::Loop(%14, %8, %15, %y.4) block0(%j : int, %56 : Dynamic, %57 : Dynamic) { - %58 : Number = prim::Constant[value=1]() + %58 : int = prim::Constant[value=1]() %y.5 : Dynamic = aten::add(%57, 
%56, %58) %60 : int = prim::Constant[value=1]() %61 : int = prim::Constant[value=1]() diff --git a/test/expect/TestScript.test_math_schema.expect b/test/expect/TestScript.test_math_schema.expect index 7d8f8d2800e84c..cff719dabb8ec6 100644 --- a/test/expect/TestScript.test_math_schema.expect +++ b/test/expect/TestScript.test_math_schema.expect @@ -1,5 +1,6 @@ graph(%x : Dynamic %y : Dynamic) { - %2 : Dynamic = aten::add[alpha={1}](%x, %y) - return (%2); + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::add(%x, %y, %2) + return (%3); } diff --git a/test/expect/TestScript.test_math_tensor_number.expect b/test/expect/TestScript.test_math_tensor_number.expect index c0a88913280e59..fb4b81bd00cba5 100644 --- a/test/expect/TestScript.test_math_tensor_number.expect +++ b/test/expect/TestScript.test_math_tensor_number.expect @@ -1,4 +1,6 @@ graph(%x : Dynamic) { - %1 : Dynamic = aten::add[other={7}, alpha={1}](%x) - return (%1); + %1 : int = prim::Constant[value=7]() + %2 : int = prim::Constant[value=1]() + %3 : Dynamic = aten::add(%x, %1, %2) + return (%3); } diff --git a/test/onnx/expect/TestOperators.test_batchnorm_training.expect b/test/onnx/expect/TestOperators.test_batchnorm_training.expect index 9bdadb572b7c03..24cdc2529af7bf 100644 --- a/test/onnx/expect/TestOperators.test_batchnorm_training.expect +++ b/test/onnx/expect/TestOperators.test_batchnorm_training.expect @@ -11,8 +11,8 @@ graph { output: "6" output: "7" output: "8" - output: "batch_norm_dead_output-9" - output: "batch_norm_dead_output-10" + output: "batch_norm_dead_output-13" + output: "batch_norm_dead_output-14" op_type: "BatchNormalization" attribute { name: "epsilon" diff --git a/test/onnx/test_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py similarity index 100% rename from test/onnx/test_caffe2.py rename to test/onnx/test_pytorch_onnx_caffe2.py diff --git a/test/test_jit.py b/test/test_jit.py index 544bea76f174e3..1459b1f70564d4 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -889,6 +889,7 @@ def fn(a, b): def test_alexnet(self): x = torch.ones(1, 3, 224, 224) trace, _ = torch.jit.get_trace_graph(torchvision.models.AlexNet(), x) + self.run_pass('cse', trace) self.assertExpectedGraph(trace) # Inplace copies don't work with tracer yet. 
@@ -1155,20 +1156,20 @@ def tanh(a): def test_batch_elementwise_binary(self): @torch.jit.batch(batch_size=4) - def add(a, b): - return a + b + def mul(a, b): + return a * b xs, batch = self.rand_batch(4, (True, 3), (False, 2)) xs2, batch2 = xs, batch - res_batch = add(batch, batch2) - res = [torch.add(xs[j], xs2[j]) for j in range(4)] + res_batch = mul(batch, batch2) + res = [torch.mul(xs[j], xs2[j]) for j in range(4)] self.assertEqual(res, res_batch.examples()) # test broadcast xs, batch = self.rand_batch(4, (False, 3), (False, 2)) b = torch.rand(3, 2) - res_batch = add(batch, b) - res = [torch.add(xs[j], b) for j in range(4)] + res_batch = mul(batch, b) + res = [torch.mul(xs[j], b) for j in range(4)] self.assertEqual(res, res_batch.examples()) def test_batch_mm(self): @@ -1231,6 +1232,7 @@ def where(c, a, b): res = [torch.where(xs_cond[j], xs[j], xs2[j]) for j in range(4)] self.assertEqual(res, res_batch.examples()) + @unittest.skip("Need support for scalar arguments") def test_lstm_cell(self): def LSTMCell(x, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c): i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index b31cd45ec47a6d..9050009f62c38e 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -134,13 +134,7 @@ PRE_RECORD_TRACE = CodeTemplate("""\ jit::tracer::PreTraceInfo trace_info; if (jit::tracer::isTracing()) { - trace_info = jit::tracer::preRecordTrace( jit::aten::${trace_name}, ${trace_inputs} ); - if (!jit::tracer::ArgumentStash::empty()) { - ${record_positional_attributes} - AT_ASSERT(jit::tracer::ArgumentStash::empty()); - } else { - ${record_attributes} - } + trace_info = jit::tracer::preRecordTrace(jit::aten::${trace_name}, ${trace_inputs}); } """) @@ -387,55 +381,8 @@ def emit_record_trace(env): if not should_trace(declaration): return ('', '') - # Note [clang-802.0.42 tuple overload bug] - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Originally, my plan for emit_$ecord_trace was to keep it as - # simple as possible, if at the expense of some somewhat ugly - # overloads. So this meant we had a 'recordTrace' function - # with overloads like this: - # - # recordTrace(..., const Variable& out) - # recordTrace(..., const std::tuple& out) - # - # Unfortunately, this triggers a bug in clang-802.0.42 - # (widely used in macOS Sierra 10.12.6) wherein a Variable is - # implicitly convertible into a std::tuple; - # a minimal repro can be seen below here: - # - # #include - # struct T {}; - # void f(const std::tuple&) {} - # void g(T& x) { f(x); } - # - # To work around this bug, the code generator is a bit more - # complicated, and is taught how to handle this situation. 
- local = {} - - tensor_args = [arg for arg in declaration['arguments'] if arg['simple_type'] in {'Tensor', 'TensorList'}] - local['tensor_args'] = [arg['name'] for arg in tensor_args] - if any(arg['simple_type'] == 'TensorList' for arg in tensor_args): - # Allocate a temporary vector with flatten and pass it in - local['trace_inputs'] = CodeTemplate("flatten_tensor_args( $tensor_args )").substitute(local) - else: - local['trace_inputs'] = CodeTemplate("{ ${tensor_args} }").substitute(local) - - local['record_attributes'] = [] - for arg in declaration['arguments']: - if arg['simple_type'] in {'Tensor', 'TensorList'}: - continue - attr_name = RENAME_ATTRIBUTES.get((declaration['name'], arg['name']), arg['name']) - local['record_attributes'].append(RECORD_ATTRIBUTE.substitute(attr_name=attr_name, name=arg['name'])) - - local['record_positional_attributes'] = [] - for i, arg in enumerate(declaration['arguments']): - if arg['simple_type'] == 'Tensor': - continue - if arg['simple_type'] == 'TensorList': - local['record_positional_attributes'] = POSITIONAL_ATTR_NYI - break - local['record_positional_attributes'].append( - RECORD_POSITIONAL_ATTRIBUTE.substitute(name=arg['name'], i=i)) + local['trace_inputs'] = sum([['"{}"'.format(arg['name']), arg['name']] for arg in declaration['arguments']], []) # Record inplace operations as out-of-place operations (e.g., # not add_ but add) diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index d16c6d303f340b..2f1adf0ab59f4b 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -38,67 +38,6 @@ using namespace at; using namespace torch::autograd::generated; namespace torch { namespace autograd { -// Helper methods for working with Attributes (torch/csrc/jit/attributes.h) - -at::Tensor maybeUnwrapVar(const at::Tensor& t) { - return t.is_variable() ? Variable(t).data() : t; -} - -// The overloaded accessors are convenient for the generated code (since we -// don't want to make the codegen do the dispatch manually) -static void setattr(jit::Node* n, jit::Symbol name, int64_t v) { n->i_(name, v); } -static void setattr(jit::Node* n, jit::Symbol name, const at::Scalar& v) { n->t_(name, maybeUnwrapVar(v.toTensor())); } -static void setattr(jit::Node* n, jit::Symbol name, SparseTensorRef s) { n->t_(name, s.tref); } -static void setattr(jit::Node* n, jit::Symbol name, const at::IntList& v) { n->is_(name, v); } -static void setattr(jit::Node* n, jit::Symbol name, bool v) { n->i_(name, v); } -static void setattr(jit::Node* n, jit::Symbol name, double v) { n->f_(name, v); } -static void setattr(jit::Node* n, jit::Symbol name, std::string v) { n->s_(name, v); } -template -static void setattr(jit::Node* n, jit::Symbol name, std::array v) { n->is_(name, std::vector(v.begin(), v.end())); } - -static jit::Value* insertConstant(jit::Node* n, jit::IValue value) { - jit::WithInsertPoint guard(n); - return insertConstant(*n->owningGraph(), std::move(value)); -} - -static void genericInsertInput(jit::Node* n, size_t idx, jit::IValue value) { - n->insertInput(idx, insertConstant(n, std::move(value))); -} - -void failPositionalAttr() { - throw std::runtime_error("unsupported type in setposattr. 
File a bug report!"); -} - -static void setposattr(jit::Node* n, size_t idx, const char *name, int64_t v) { genericInsertInput(n, idx, v); } -static void setposattr(jit::Node* n, size_t idx, const char *name, const at::Scalar& v) { genericInsertInput(n, idx, v); } -static void setposattr(jit::Node* n, size_t idx, const char *name, SparseTensorRef s) { failPositionalAttr(); } -static void setposattr(jit::Node* n, size_t idx, const char *name, const at::IntList& v) { - using ArgumentStash = jit::tracer::ArgumentStash; - if (ArgumentStash::hasIntList(name)) { - auto info = ArgumentStash::popIntList(name); - for (size_t i = 0; i < info.size(); ++i) { - if (info[i] != nullptr) continue; - info[i] = insertConstant(n, v[i]); - } - for (jit::Value* v : info) { - if (*v->type() != *jit::IntType::get()) { - throw std::runtime_error( - "Type mismatch in setposattr for IntList. Check that your program " - "is valid without tracing, and please file a bug report if it is."); - } - } - jit::WithInsertPoint insert_point{n}; - auto& g = *n->owningGraph(); - auto size = g.insertNode(g.createList(jit::IntType::get(), info))->output(); - n->insertInput(idx, size); - } else { - return genericInsertInput(n, idx, v); - } -} -static void setposattr(jit::Node* n, size_t idx, const char *name, bool v) { genericInsertInput(n, idx, v); } -static void setposattr(jit::Node* n, size_t idx, const char *name, double v) { genericInsertInput(n, idx, v); } -template -static void setposattr(jit::Node* n, size_t idx, const char *name, std::array v) { failPositionalAttr(); } VariableType::VariableType(Context* context, Type* baseType) : Type(context, /*is_variable=*/true, /*is_undefined=*/false) diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py index 18c043a6c1061d..ad9ad2e05c4f4c 100644 --- a/tools/jit/gen_jit_dispatch.py +++ b/tools/jit/gen_jit_dispatch.py @@ -84,7 +84,7 @@ def from_attribute(arg): 'Scalar': '{}.toScalar()', 'ScalarType': 'static_cast({}.toInt())', 'Tensor': '{}.toTensor()', - 'bool': '{}.toInt()', + 'bool': 'bool({}.toInt())', 'double': '{}.toDouble()', 'int64_t': '{}.toInt()', 'std::array': 'as_bool_array<2>({}.toIntList()->elements())', diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index b9dfa25d9ee870..f2165f4efa6d83 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -140,8 +140,7 @@ PyObject* createPyObject(const at::Storage& storage) bool isStorage(PyObject* obj) { - auto it = py_storage_type_to_attype.find(Py_TYPE(obj)); - return it != py_storage_type_to_attype.end(); + return py_storage_type_to_attype.count(Py_TYPE(obj)); } std::unique_ptr createStorage(PyObject* obj) { diff --git a/torch/csrc/Exceptions.cpp b/torch/csrc/Exceptions.cpp index 306dd3a70a4228..4de8a71591e406 100644 --- a/torch/csrc/Exceptions.cpp +++ b/torch/csrc/Exceptions.cpp @@ -109,6 +109,10 @@ static std::string formatMessage(const char *format, va_list fmt_args) { static const size_t ERROR_BUF_SIZE = 1024; char error_buf[ERROR_BUF_SIZE]; vsnprintf(error_buf, ERROR_BUF_SIZE, format, fmt_args); + + // Ensure that the string is null terminated + error_buf[sizeof(error_buf) / sizeof(*error_buf) - 1] = 0; + return std::string(error_buf); } diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h index 69b5036766e998..d6bd90cb708784 100644 --- a/torch/csrc/jit/argument_spec.h +++ b/torch/csrc/jit/argument_spec.h @@ -153,7 +153,7 @@ struct ArgumentInfo { operator TypePtr() const { if(!defined()) return DynamicType::get(); - return 
std::make_shared(type(), device(), sizes(), strides()); + return TensorType::create(type(), device(), sizes(), strides()); } private: // offsetinto sizes_strides() array where the sizes start for tensor j diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index ceb379a53925d4..72d51bc2f304b9 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -1,7 +1,9 @@ #include "torch/csrc/jit/autodiff.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" +#include "torch/csrc/jit/passes/common_subexpression_elimination.h" #include "torch/csrc/jit/symbolic_variable.h" +#include "torch/csrc/jit/operator.h" #include "torch/csrc/utils/functional.h" #include @@ -13,36 +15,66 @@ namespace torch { namespace jit { using value_map = std::unordered_map; using value_set = std::unordered_set; -bool hasOneValuedInput(Node *n, torch::jit::Symbol name) { - auto maybe_t = n->get(name); - if (!maybe_t) return false; - return maybe_t->toDouble() == 1.0; +void wrapDim(int64_t & dim, const std::vector & sizes) { + if (dim < 0) { + dim += sizes.size(); + } } bool isDifferentiable(Node * n) { - static std::unordered_set differentiable_kinds = { - aten::add, aten::sub, aten::mul, prim::Constant, - aten::sigmoid, aten::tanh, aten::mm, aten::chunk, aten::split, aten::t, aten::neg, - aten::unsqueeze, aten::expand, aten::addmm, aten::gt, aten::lt, aten::eq, aten::ne, aten::ge, aten::le, aten::type_as, - aten::relu, aten::exp, prim::AutogradAdd + static OperatorSet differentiable_ops = { + "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", + "aten::add(Tensor self, Scalar other, *, Scalar alpha) -> Tensor", + "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", + "aten::sub(Tensor self, Scalar other, *, Scalar alpha) -> Tensor", + "aten::mul(Tensor self, Tensor other) -> Tensor", + "aten::mul(Tensor self, Scalar other) -> Tensor", + "aten::sigmoid(Tensor self) -> Tensor", + "aten::tanh(Tensor self) -> Tensor", + "aten::relu(Tensor self) -> Tensor", + "aten::exp(Tensor self) -> Tensor", + "aten::t(Tensor self) -> Tensor", + "aten::neg(Tensor self) -> Tensor", + "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]", + "aten::split(Tensor self, int split_size, int dim) -> Tensor[]", + "aten::type_as(Tensor self, Tensor other) -> Tensor", + "aten::unsqueeze(Tensor self, int dim) -> Tensor", + "aten::mm(Tensor self, Tensor mat2) -> Tensor", + "aten::lt(Tensor self, Tensor other) -> Tensor", + "aten::le(Tensor self, Tensor other) -> Tensor", + "aten::gt(Tensor self, Tensor other) -> Tensor", + "aten::ge(Tensor self, Tensor other) -> Tensor", + "aten::eq(Tensor self, Tensor other) -> Tensor", + "aten::ne(Tensor self, Tensor other) -> Tensor" }; - // TODO: check this more generally via schema - // This check ensures that the `alpha` and `beta` attributes on this addmm - // node are constant and equivalent to 1.0 - if (n->kind() == aten::addmm) { - if (n->inputs().size() > 3) - return false; - if (!hasOneValuedInput(n, attr::alpha) || !hasOneValuedInput(n, attr::beta)) - return false; - } - auto isTensor = [](Value* v) { return v->type()->isSubtypeOf(*DynamicType::get()); }; - if(!std::all_of(n->inputs().begin(), n->inputs().end(), isTensor) - || !std::all_of(n->outputs().begin(), n->outputs().end(), isTensor)) - return false; + if (n->kind() == prim::Constant || n->kind() == prim::AutogradAdd) + return true; + if (differentiable_ops.find(n)) + return true; - if (n->kind() == aten::type_as && !n->inputs().at(1)->isTensor()) { - return false; + if 
(n->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) { + return static_cast(n->input(1)->type()->cast()); + } + if (n->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor")) { + if (!n->is_constant(attr::dim)) return false; + for (Value * input : n->inputs().slice(0, n->inputs().size() - 1)) { + if (!input->type()->cast()) return false; + } + return true; + } + if (n->matches("aten::squeeze(Tensor self) -> Tensor")) { + return static_cast(n->input()->type()->cast()); + } + if (n->matches("aten::squeeze(Tensor self, int dim) -> Tensor")) { + return n->namedInput(attr::self)->type()->cast() && n->is_constant(attr::dim); + } + if (n->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor")) { + return n->is_constant(attr::size) && n->is_constant(attr::implicit); + } + if (n->matches("aten::view(Tensor self, int[] size) -> Tensor") || + n->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) { + return static_cast(n->namedInput(attr::self)->type()->cast()); } // linear blocks may appear as inputs to graph executors, but they are removed @@ -55,7 +87,7 @@ bool isDifferentiable(Node * n) { static_cast(isDifferentiable)); } - return differentiable_kinds.count(n->kind()) > 0; + return false; } @@ -83,146 +115,149 @@ bool outputRequiresGrad(Node* node, std::function requires_grad) { } } - - static std::vector gradientForNode(Node* node, ArrayRef grad_values) { const auto build_sym_grad = [node](const std::vector& grads) -> std::vector { auto inputs = fmap(node->inputs()); auto outputs = fmap(node->outputs()); - switch(node->kind()) { - case aten::add: - // TODO (apaszke): remove formulas for attributed nodes once they are removed - // o = self + alpha*other - if(inputs.size() == 1) { - return { grads.at(0) }; - } else if (node->hasAttribute(attr::alpha)) { - return {grads.at(0), grads.at(0) * at::Scalar(node->t(attr::alpha))}; - } else { - return {grads.at(0), nullptr, grads.at(0) * node->namedInput(attr::alpha)}; - } - case aten::sub: - // o = self - alpha*other - if(inputs.size() == 1) { - return {grads.at(0)}; - } else if (node->hasAttribute(attr::alpha)) { - return {grads.at(0), -grads.at(0) * at::Scalar(node->t(attr::alpha))}; - } else { - return {grads.at(0), nullptr, grads.at(0) * node->namedInput(attr::alpha)}; - } - case aten::mul: - // o = self * other - if(inputs.size() == 1) - return {grads.at(0) * at::Scalar(node->t(attr::other))}; - else - return {grads.at(0) * inputs.at(1), grads.at(0) * inputs.at(0)}; - case prim::Constant: - return {}; - case aten::sigmoid: - return {grads.at(0) * outputs.at(0) * (1 - outputs.at(0))}; - case aten::tanh: - return {grads.at(0) * (1 - outputs.at(0) * outputs.at(0))}; - case aten::relu: - return {grads.at(0) * (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))}; - case aten::exp: - return {grads.at(0) * (outputs.at(0))}; - case aten::chunk: - case aten::split: - return {SymbolicVariable::cat(grads, node->namedInput(attr::dim))}; - case aten::t: - return {grads.at(0).t()}; - case aten::neg: - return {-grads.at(0)}; - case aten::view: - // TODO: if sizes are not available statically, add an operator that reutrns them as a tuple - return {grads.at(0).view(inputs.at(0).sizes())}; - case aten::type_as: - return {grads.at(0).type_as(inputs.at(0))}; - case aten::unsqueeze: - return {grads.at(0).squeeze(node->namedInput(attr::dim))}; - case aten::mm: { - SymbolicVariable dmat1, dmat2; - if (auto type = inputs.at(0).value()->type()->cast()) { - auto sizes = type->sizes(), strides = type->strides(); - if 
(strides.at(0) == 1 && strides.at(1) == sizes.at(0)) { - dmat1 = inputs.at(1).mm(grads.at(0).t()).t(); - } else { - dmat1 = grads.at(0).mm(inputs.at(1).t()); - } + + if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") || + node->matches("aten::add(Tensor self, Scalar other, *, Scalar alpha) -> Tensor")) { + return {grads.at(0), grads.at(0) * node->namedInput(attr::alpha), nullptr}; + + } else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") || + node->matches("aten::sub(Tensor self, Scalar other, *, Scalar alpha) -> Tensor")) { + return {grads.at(0), -grads.at(0) * node->namedInput(attr::alpha), nullptr}; + + } else if (node->matches("aten::mul(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) { + return {grads.at(0) * inputs.at(1), grads.at(0) * inputs.at(0)}; + + } else if (node->matches("aten::sigmoid(Tensor self) -> Tensor")) { + return {grads.at(0) * outputs.at(0) * (1 - outputs.at(0))}; + + } else if (node->matches("aten::tanh(Tensor self) -> Tensor")) { + return {grads.at(0) * (1 - outputs.at(0) * outputs.at(0))}; + + } else if (node->matches("aten::relu(Tensor self) -> Tensor")) { + return {grads.at(0) * (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))}; + + } else if (node->matches("aten::exp(Tensor self) -> Tensor")) { + return {grads.at(0) * (outputs.at(0))}; + + } else if (node->matches("aten::t(Tensor self) -> Tensor")) { + return {grads.at(0).t()}; + + } else if (node->matches("aten::neg(Tensor self) -> Tensor")) { + return {-grads.at(0)}; + + } else if (node->matches("aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]") || + node->matches("aten::split(Tensor self, int split_size, int dim) -> Tensor[]")) { + return {SymbolicVariable::cat(grads, node->namedInput(attr::dim)), nullptr, nullptr}; + + } else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor") || + node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) { + // TODO: if sizes are not available statically, add an operator that reutrns them as a tuple + auto sizes = node->namedInput(attr::self)->type()->expect()->sizes(); + return {grads.at(0).reshape(sizes), nullptr}; + + } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) { + return {grads.at(0).type_as(inputs.at(0)), nullptr}; + + } else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor")) { + return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr}; + + } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) { + SymbolicVariable dmat1, dmat2; + if (auto type = inputs.at(0).value()->type()->cast()) { + auto sizes = type->sizes(), strides = type->strides(); + if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) { + dmat1 = inputs.at(1).mm(grads.at(0).t()).t(); } else { dmat1 = grads.at(0).mm(inputs.at(1).t()); } - if (auto type = inputs.at(1).value()->type()->cast()) { - auto sizes = type->sizes(), strides = type->strides(); - if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) { - dmat2 = grads.at(0).t().mm(inputs.at(0)).t(); - } else { - dmat2 = inputs.at(0).t().mm(grads.at(0)); - } + } else { + dmat1 = grads.at(0).mm(inputs.at(1).t()); + } + if (auto type = inputs.at(1).value()->type()->cast()) { + auto sizes = type->sizes(), strides = type->strides(); + if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) { + dmat2 = grads.at(0).t().mm(inputs.at(0)).t(); } else { dmat2 = inputs.at(0).t().mm(grads.at(0)); } - return 
{dmat1, dmat2}; + } else { + dmat2 = inputs.at(0).t().mm(grads.at(0)); } - case aten::expand: { - const auto& input_sizes = inputs.at(0).sizes(); - if (input_sizes.size() == 0) - return {grads.at(0).sum()}; - auto grad_sizes = node->get>(attr::size).value(); - auto grad = grads.at(0); - while (grad_sizes.size() > input_sizes.size()) { - grad = grad.sum(0, false); - grad_sizes.erase(grad_sizes.begin()); - } - for (size_t i = 0; i < input_sizes.size(); ++i) { - if (input_sizes[i] == 1 && grad_sizes[i] > 1) { - grad = grad.sum(i, true); - } - } - return {grad}; + return {dmat1, dmat2}; + + } else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor")) { + const auto& input_sizes = inputs.at(0).sizes(); + if (input_sizes.size() == 0) + return {grads.at(0).sum(), nullptr, nullptr}; + auto grad_sizes = node->get>(attr::size).value(); + auto grad = grads.at(0); + while (grad_sizes.size() > input_sizes.size()) { + grad = grad.sum(0, false); + grad_sizes.erase(grad_sizes.begin()); } - case aten::squeeze: { - const auto& sizes = inputs.at(0).sizes(); - // TODO (apaszke): need to select the right overload here - if (node->hasAttribute(attr::dim)) { - int dim = node->i(attr::dim); - return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim)}; - } else { - std::vector squeezed_dims; - for (size_t i = 0; i < sizes.size(); ++i) { - if (sizes[i] != 1) continue; - squeezed_dims.push_back(i); - } - SymbolicVariable returned_grad = grads.at(0); - for (auto it = squeezed_dims.rbegin(); it != squeezed_dims.rend(); ++it) - returned_grad = returned_grad.unsqueeze(*it); - return {returned_grad}; + for (size_t i = 0; i < input_sizes.size(); ++i) { + if (input_sizes[i] == 1 && grad_sizes[i] > 1) { + grad = grad.sum(i, true); } } - case aten::cat: { - int dim = node->get(attr::dim).value(); - const auto& first_sizes = inputs.at(0).sizes(); - const auto has_first_sizes = [&first_sizes](SymbolicVariable var) { - return var.sizes() == first_sizes; - }; - // TODO (apaszke): This will need an adjustment for the dim argument - // NB: this is a specialization for the common case where all inputs are - // of equal sizes. We can use a single split operation to handle that. - if (std::all_of(inputs.begin(), inputs.end(), has_first_sizes)) { - return grads.at(0).chunk(inputs.size(), dim); - } else { - size_t offset = 0; - auto grad = grads.at(0); - std::vector returned_grads; - for (auto input : inputs) { - returned_grads.push_back(grad.narrow(dim, offset, input.sizes()[dim])); - offset += input.sizes()[dim]; - } - return returned_grads; + return {grad, nullptr, nullptr}; + + } else if (node->matches("aten::squeeze(Tensor self) -> Tensor")) { + const auto& sizes = inputs.at(0).sizes(); + std::vector squeezed_dims; + for (size_t i = 0; i < sizes.size(); ++i) { + if (sizes[i] != 1) continue; + squeezed_dims.push_back(i); + } + SymbolicVariable returned_grad = grads.at(0); + for (auto it = squeezed_dims.begin(); it != squeezed_dims.end(); ++it) + returned_grad = returned_grad.unsqueeze(*it); + return {returned_grad}; + + } else if (node->matches("aten::squeeze(Tensor self, int dim) -> Tensor", /*const=*/attr::dim)) { + int64_t dim = *node->get(attr::dim); + const auto& sizes = inputs.at(0).sizes(); + wrapDim(dim, sizes); + if (sizes.size() == 0) { + return {grads.at(0), nullptr}; + } + return {sizes.at(dim) > 1 ? 
grads.at(0) : grads.at(0).unsqueeze(dim), nullptr}; + + } else if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor", /*const=*/attr::dim)) { + int dim = *node->get(attr::dim); + auto tensor_inputs = inputs; tensor_inputs.pop_back(); + const auto& first_sizes = tensor_inputs.at(0).sizes(); + const auto has_first_sizes = [&first_sizes](SymbolicVariable var) { + return var.sizes() == first_sizes; + }; + + // NB: this is a specialization for the common case where all inputs are + // of equal sizes. We can use a single split operation to handle that. + if (std::all_of(tensor_inputs.begin(), tensor_inputs.end(), has_first_sizes)) { + auto tensor_grads = grads.at(0).chunk(tensor_inputs.size(), dim); + tensor_grads.push_back(nullptr); // for attr::dim + return tensor_grads; + } else { + size_t offset = 0; + auto grad = grads.at(0); + std::vector tensor_grads; + for (auto input : tensor_inputs) { + tensor_grads.push_back(grad.narrow(dim, offset, input.sizes()[dim])); + offset += input.sizes()[dim]; } + tensor_grads.push_back(nullptr); // for attr::dim + return tensor_grads; } + + } else if (node->kind() == prim::Constant) { + return {}; } - throw std::runtime_error(std::string("don't support differentiation of `") + - node->kind().toDisplayString() + "`"); + throw std::runtime_error(std::string("failed to differentiate `") + node->kind().toDisplayString() + "`"); }; if (!isDifferentiable(node)) { throw std::runtime_error(std::string("differentiation of ") + node->kind().toDisplayString() + " " @@ -273,15 +308,13 @@ static std::vector linearGradientForNode(Node* node, ArrayRef gr // to make reading gradient graphs easier, remember the name of the forward op linear->s_(attr::name, node->kind().toDisplayString()); auto block = linear->addBlock(); - { - WithInsertPoint guard(block); - auto results = gradientForNode(node, grad_values); - for(auto r : results) { - block->registerOutput(r); - linear->addOutput()->copyMetadata(r); - } - } - return linear->outputs(); + WithInsertPoint guard(block); + auto results = gradientForNode(node, grad_values); + return fmap(results, [block, linear](Value *grad) -> Value* { + if (!grad) return nullptr; + block->registerOutput(grad); + return linear->addOutput()->copyMetadata(grad); + }); } struct ReverseDetails { @@ -377,6 +410,40 @@ static ReverseDetails addReverseInline(Gradient& grad_desc, return ReverseDetails(std::move(grad_map), std::move(requires_grad_set), reverse_block); } +// Any temporary value from the primal graphs needs to be captured for later use in the +// reverse graph, to avoid costly recomputations. However, a lot of the nodes we have +// in our graphs are simply constants, which are cheap to execute and replicate, and so +// it's better to just copy them into the reverse graph, without polluting the output +// lists unnecessarily. 
+static void liftConstants(Gradient& grad_desc, ReverseDetails& rev_info) { + static const auto err = [](Value*) -> Value* { + throw std::runtime_error("unexpected input"); + }; + auto & graph = *grad_desc.f; + Block* reverse_block = rev_info.reverse_block; + + for (Node *top_node : reverse_block->nodes()) { + JIT_ASSERT(top_node->kind() == prim::GradOf || + top_node->kind() == prim::AutogradAdd || + top_node->kind() == prim::Undefined); + if (top_node->kind() != prim::GradOf) continue; + Block * grad_body = top_node->blocks().at(0); + for (Node *node : grad_body->nodes()) { + for (Value * input : node->inputs()) { + if (input->node()->kind() != prim::Constant) continue; + if (input->node()->owningBlock() == grad_body) continue; + Node *lifted_constant = graph.createClone(input->node(), err); + reverse_block->prependNode(lifted_constant); + node->replaceInputWith(input, lifted_constant->output()); + } + } + } + + // It's possible that we've cloned the same constants many times, + // so we use CSE to deduplicate them. + EliminateCommonSubexpression(reverse_block); +} + // Takes a grad_desc.f returned from `addReverseInline` and splits off the // reverse_block into its own graph, storing it in df. // All intermediates needed in the second stage are added to @@ -516,6 +583,8 @@ Gradient differentiate(std::shared_ptr& _graph, const std::vector& WithInsertPoint guard(grad_desc.f->block()); // Fills in df_input_vjps and df_output_vjps auto rev_info = addReverseInline(grad_desc, requires_grad); + // Lift constants captured for the reverse graph into it + liftConstants(grad_desc, rev_info); // addReverseInline has to call gradientForNode if *any* of the outputs // require grad, but it will emit vjps for *all* outputs. Use DCE to remove // unnecessary nodes.
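As a reading aid for the rewritten gradientForNode above: the returned vector now carries one entry per input of the forward node, with nullptr marking inputs that receive no gradient (for example the size/implicit inputs of aten::expand, or the trailing dim input of aten::cat), and linearGradientForNode forwards those nullptrs as missing outputs. The following is only a minimal sketch of a consumer of that convention; the names wireGradients and grad_map are hypothetical and not part of this patch.

#include <cassert>
#include <unordered_map>
#include <vector>
#include "torch/csrc/jit/ir.h"

using torch::jit::Node;
using torch::jit::Value;

// Hypothetical helper: route each non-null gradient back to the matching
// forward input; null entries (int/int[]/Scalar inputs) are simply skipped.
static void wireGradients(Node* node, const std::vector<Value*>& input_grads,
                          std::unordered_map<Value*, Value*>& grad_map) {
  assert(input_grads.size() == node->inputs().size());
  for (size_t i = 0; i < input_grads.size(); ++i) {
    if (input_grads[i] == nullptr) continue;  // e.g. the dim argument of aten::cat
    grad_map[node->input(i)] = input_grads[i];
  }
}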
diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp index 1c8bf928aab5dd..3c4ad0c130ea31 100644 --- a/torch/csrc/jit/constants.cpp +++ b/torch/csrc/jit/constants.cpp @@ -38,14 +38,14 @@ RegisterOperators reg({ prim::Constant, [](Node* node) -> Operation { TypePtr type = node->output()->type(); - if(type->isSubtypeOf(*DynamicType::get())) { + if(type->isSubtypeOf(DynamicType::get())) { auto t = autograd::make_variable(node->t(attr::value)); return [t](Stack& stack) { stack.push_back(t); return 0; }; } else if ( - type->isSubtypeOf(*NumberType::get()) && + type->isSubtypeOf(NumberType::get()) && node->kindOf(attr::value) == AttributeKind::i) { auto i = node->i(attr::value); return [i](Stack& stack) { @@ -53,14 +53,14 @@ RegisterOperators reg({ return 0; }; } else if ( - type->isSubtypeOf(*NumberType::get()) && + type->isSubtypeOf(NumberType::get()) && node->kindOf(attr::value) == AttributeKind::f) { auto f = node->f(attr::value); return [f](Stack& stack) { push(stack, f); return 0; }; - } else if(type->isSubtypeOf(*ListType::ofInts())) { + } else if(type->isSubtypeOf(ListType::ofInts())) { auto is = node->is(attr::value); return [is](Stack& stack) { push(stack, is); diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index 90120f1be0fb95..71dec999c40216 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -156,7 +156,7 @@ void addAttribute(onnx::NodeProto * n_p, jit::Node * n, jit::Symbol name, Export void encodeTypeProtoTensorType(onnx::TypeProtoTensor* tensor_type, Value* n) { onnx::TensorShapeProto* shape = tensor_type->mutable_shape(); - if (TensorType* node_type = n->type()->cast()) { + if (TensorTypePtr node_type = n->type()->cast()) { const std::vector& sizes = node_type->sizes(); for (std::int64_t s : sizes) { shape->add_dim(s); diff --git a/torch/csrc/jit/fusion_compiler.cpp b/torch/csrc/jit/fusion_compiler.cpp index 7e0db38e5b5614..8d20045efefe6a 100644 --- a/torch/csrc/jit/fusion_compiler.cpp +++ b/torch/csrc/jit/fusion_compiler.cpp @@ -4,6 +4,7 @@ #include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/code_template.h" #include "torch/csrc/jit/resource_guard.h" +#include "torch/csrc/jit/constants.h" #include "torch/csrc/utils/disallow_copy.h" #include "torch/csrc/variable_tensor_functions.h" @@ -196,15 +197,14 @@ static std::string valueName(Value * n) { return "n" + std::to_string(n->unique()); } -static std::string scalarValue(const at::Tensor & t) { - auto s = at::Scalar(t); - if (s.isIntegral()){ - return std::to_string(s.toLong()); - } else { - std::ostringstream out; - out << std::scientific << s.toDouble() << "f"; - return out.str(); - } +static std::string scalarValue(int64_t v) { + return std::to_string(v); +} + +static std::string scalarValue(double v) { + std::ostringstream out; + out << std::scientific << v << "f"; + return out.str(); } static const char * scalarTypeName(at::ScalarType type) { @@ -280,42 +280,31 @@ std::string encodeRHS(Node * n) { {aten::remainder, "remainderf(${0}, ${1})"}, {aten::pow, "powf(${0}, ${1})"}, - //alpha - {aten::add, "${0} + ${alpha}*${1}"}, - {aten::sub, "(${0} - ${alpha}*${1})"}, - - // special - {aten::lerp, "${0} + ${weight}*(${1} - ${0})"}, - {aten::clamp, "min(max(${0},${min}),${max})"}, + // binary with alpha + {aten::add, "${0} + ${2}*${1}"}, + {aten::sub, "(${0} - ${2}*${1})"}, // simple derivatives {aten::_sigmoid_backward, "${0} * ${1} * (1.f - ${1})"}, {aten::_tanh_backward, "${0} * (1.f - ${1} * ${1})"}, }; + if (n->kind() == prim::Constant) { + auto val = 
toIValue(n->output()).value(); + if (val.isDouble()) { + return scalarValue(val.toDouble()); + } else { + JIT_ASSERT(val.isInt()); + return scalarValue(val.toInt()); + } + } TemplateEnv env; size_t i = 0; for(auto in : n->inputs()) { env.s(std::to_string(i++), valueName(in)); } - // TODO (apaszke): remove once we get rid of attributes - // ops like div have a / b or a / 2 with the constant having the attribute other - // so we add other as an input if it is present - // 'pow' is the same but uses exponent as the attribute, so we handle that here as well - if(n->hasAttribute(attr::other) || n->hasAttribute(attr::exponent)) { - env.s(std::to_string(i), scalarValue(n->t(attr::other))); - } - // we also add any other scalar tensors to the env for special ops - for(auto a : n->attributeNames()) { - if(n->kindOf(a) == AttributeKind::t) { - auto v = n->t(a); - if(v.dim() == 0) { - JIT_ASSERT(a.is_attr()); - env.s(a.toUnqualString(), scalarValue(v)); - } - } - } + const auto & str = simple_map_ops.at(n->kind()); return format(str, env); } @@ -362,9 +351,12 @@ std::vector emitCompilationUnit(std::ostream & out, flat_output_nodes.push_back(o); } else { auto cat = o->node(); - size_t nInputs = cat->inputs().size(); + auto tensor_inputs = cat->inputs(); + // We need to drop the dim arg + tensor_inputs = tensor_inputs.slice(0, tensor_inputs.size() - 1); + size_t nInputs = tensor_inputs.size(); concat_desc.emplace_back(desc, nInputs, cat->get(attr::dim).value()); - for(auto c : cat->inputs()) { + for(auto c : tensor_inputs) { emitFormal(c, *concat_desc.back().subtensorDesc); flat_output_nodes.push_back(c); } diff --git a/torch/csrc/jit/fusion_compiler.h b/torch/csrc/jit/fusion_compiler.h index 969cc1fc05566e..6c4759aefb692a 100644 --- a/torch/csrc/jit/fusion_compiler.h +++ b/torch/csrc/jit/fusion_compiler.h @@ -29,7 +29,7 @@ struct TensorDesc { : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {} TensorDesc(const at::Tensor& t) : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {} - TensorDesc(TensorType *type) + TensorDesc(TensorTypePtr type) : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {} // number of dimensions after contiguity compression diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index 2c595ffd679c27..df81c378ad137d 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -388,8 +388,10 @@ struct GraphExecutorImpl { auto graph_ = graph->copy(); runRequiredPasses(graph_); if(optimize) { - if(!symbolically_differentiable) + if(!symbolically_differentiable) { + EraseShapeInformation(*graph_); CreateAutodiffSubgraphs(*graph_); + } runOptimization(graph_, /*graphMustSupportVariables=*/true); } autograd_fallback_graph = graph_; diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index 2d3af265a5d651..d54b0434e7e64a 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -512,8 +512,46 @@ std::shared_ptr buildGraph(const Graph_& graph_, std::vector& return graph; } +// TODO: this should be removed once we'll be able to serialize value types +void reconstructOutputTypes(Block *b) { + for (Node * n : b->nodes()) { + if (n->kind() == prim::Constant) { + switch (n->kindOf(attr::value)) { + case AttributeKind::i: + n->output()->setType(IntType::get()); + break; + case AttributeKind::f: + n->output()->setType(FloatType::get()); + break; + case AttributeKind::is: + n->output()->setType(ListType::ofInts()); + break; + case AttributeKind::t: + 
n->output()->setType(DynamicType::get()); + break; + default: + throw std::runtime_error("Unsupported case in reconstructOutputTypes. File a bug report"); + } + } else if (n->kind() == prim::ListConstruct && n->inputs().size() > 0) { + auto input_types = fmap(n->inputs(), [](Value *v) -> TypePtr { + return v->node()->kind() == prim::Constant ? v->type() : nullptr; + }); + // Check that all types are equal + if (std::equal(std::next(input_types.begin()), input_types.end(), input_types.begin())) { + auto elem_type = input_types[0]; + if (elem_type == IntType::get()) { + n->output()->setType(ListType::ofInts()); + } + } + } + for (Block * b : n->blocks()) { + reconstructOutputTypes(b); + } + } } +} // anonymous namespace + std::shared_ptr ImportIRGraph(const std::string& serialized_graph, std::vector& initializers) { @@ -523,6 +561,8 @@ std::shared_ptr ImportIRGraph(const std::string& serialized_graph, auto graph = buildGraph(model.graph, initializers); + reconstructOutputTypes(graph->block()); + return graph; } diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index cf7dda32413c23..65bdcf695f6de2 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -147,6 +147,7 @@ static std::vector> flattenStages(Graph & graph) { while(input_pos < graph.inputs().size() && graph.inputs()[input_pos]->stage() == i) { auto nv = store->addOutput(); auto old_node = graph.inputs()[input_pos]; + nv->setType(old_node->type()); stage_input_types[i].push_back(old_node->type()); old_node->replaceAllUsesWith(nv); input_pos++; diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 3cd1d46b7df2af..7f09b22b324d11 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -240,7 +240,7 @@ static void checkSameDevice(const Node* node) { bool has_device = false; int device; auto checkValue = [&](const Value* v) { - if(TensorType* type = v->type()->cast()) { + if(TensorTypePtr type = v->type()->cast()) { if(!has_device) { has_device = true; device = type->device(); @@ -596,7 +596,7 @@ at::optional Node::get(Symbol name) const { // disambiguate via schema at::Tensor ten = t(name); const Argument* arg = findArgument(schema(), name).second; - if(arg->type->isSubtypeOf(*NumberType::get())) { + if(arg->type->isSubtypeOf(NumberType::get())) { return IValue(at::Scalar(ten)); } return IValue(ten); @@ -619,7 +619,23 @@ Value* Node::namedInput(Symbol name) const { // so this is completely unsafe and needs to be gone as soon as possible. 
return v; } - return input(findArgument(schema(), name).first); + const auto & the_schema = schema(); + int64_t tensor_list_pos = 0; + for (auto & arg : the_schema.arguments) { + if (*arg.type == *ListType::ofTensors()) + break; + tensor_list_pos++; + } + int64_t arg_pos = findArgument(schema(), name).first; + // XXX: we don't have a single value we could give for a Tensor[], + // because we flatten lists into arguments + JIT_ASSERT(arg_pos != tensor_list_pos); + // NB: if there's no tensor list, then tensor_list_pos == arguments.size(), so this is always true + if (arg_pos < tensor_list_pos) { + return input(arg_pos); + } else { + return input(inputs().size() - (the_schema.arguments.size() - arg_pos)); + } } bool Node::matches(const char *signature_literal, at::ArrayRef const_inputs) { diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 8c940626699cd7..9af468e6ee06e7 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -181,7 +181,7 @@ struct Value { public: Value* setType(const TypePtr type); void inferTypeFrom(const at::Tensor& output) { - setType(std::make_shared(output)); + setType(TensorType::create(output)); } const TypePtr & type() const { JIT_ASSERT(type_ != nullptr); @@ -995,13 +995,13 @@ friend struct Block; } Node* createTuple(at::ArrayRef values) { auto types = fmap(values, [](Value* v) { return v->type(); }); - auto tt = std::make_shared(std::move(types)); + auto tt = TupleType::create(std::move(types)); auto n = create(prim::TupleConstruct, values); n->output()->setType(tt); return n; } Node* createTupleUnpack(Value * v) { - TupleType* tt = v->type()->expect(); + TupleTypePtr tt = v->type()->expect(); auto n = create(prim::TupleUnpack, {v}, 0); for(auto & element : tt->elements()) { n->addOutput()->setType(element); @@ -1011,9 +1011,9 @@ friend struct Block; Node* createList(const TypePtr& elem_type, at::ArrayRef values) { auto n = create(prim::ListConstruct, values); for(const auto & v : values) { - JIT_ASSERT(v->type()->isSubtypeOf(*elem_type)); + JIT_ASSERT(v->type()->isSubtypeOf(elem_type)); } - n->output()->setType(std::make_shared(elem_type)); + n->output()->setType(ListType::create(elem_type)); return n; } Node* createNumToTensor(Value* value) { diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index 26e314c53eaa3a..19da2195e5b33d 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -65,7 +65,7 @@ struct SchemaParser { void parseType(Argument& arg) { arg.type = parseBaseType(); if(L.nextIf('[')) { - arg.type = std::make_shared(arg.type); + arg.type = ListType::create(arg.type); if(L.cur().kind == TK_NUMBER) { arg.N = std::stoll(L.next().text()); } @@ -210,8 +210,8 @@ struct SchemaParser { Lexer L; bool kwarg_only; }; -} +} // namespace script namespace { @@ -271,7 +271,7 @@ struct OperatorRegistry { operators_by_sig[canonicalSchemaString(op.schema)] = op_ptr; } - Operator& lookupByLiteral(const char * name) { + const std::shared_ptr& lookupByLiteral(const char * name) { auto it = operators_by_sig_literal.find(name); if (it == operators_by_sig_literal.end()) { auto op_ptr_it = operators_by_sig.find(name); @@ -286,7 +286,7 @@ struct OperatorRegistry { JIT_ASSERTM(op_ptr_it != operators_by_sig.end(), "Couldn't find an operator for %s", name); it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second); } - return *it->second; + return it->second; } const std::vector>& getOperators(Symbol name) { @@ -315,7 +315,7 @@ const std::vector>& getAllOperatorsFor(Symbol name) { } Operator& sig(const char 
*signature) { - return getRegistry().lookupByLiteral(signature); + return *getRegistry().lookupByLiteral(signature); } FunctionSchema parseSchema(const std::string& schema) { @@ -328,7 +328,7 @@ at::optional attributeKindOf(TypePtr type) { case TypeKind::FloatType: return AttributeKind::f; case TypeKind::NumberType: return AttributeKind::t; case TypeKind::ListType: - if(type->isSubtypeOf(*ListType::ofInts())) + if(type->isSubtypeOf(ListType::ofInts())) return AttributeKind::is; else return at::nullopt; @@ -338,7 +338,7 @@ at::optional attributeKindOf(TypePtr type) { } bool typeMatches(TypePtr actual, TypePtr formal) { - return actual->isSubtypeOf(*formal); + return actual->isSubtypeOf(formal); } bool Operator::matches(const Node* node) const { @@ -431,4 +431,26 @@ const Operator& getOperatorFor(const Node* node) { throw er; } + +OperatorSet::OperatorSet(std::initializer_list sig_literals) { + auto & registry = getRegistry(); + for (const char * sig : sig_literals) { + auto op = registry.lookupByLiteral(sig); + ops[Symbol::fromQualString(op->schema.name)].push_back(op); + } +} + +Operator* OperatorSet::find(Node *n) { + auto it = ops.find(n->kind()); + if (it == ops.end()) { + return nullptr; + } + for (auto & op : it->second) { + if (op->matches(n)) { + return op.get(); + } + } + return nullptr; +} + }} diff --git a/torch/csrc/jit/operator.h b/torch/csrc/jit/operator.h index 47ed788770f1cb..7e6a314d2cb8c3 100644 --- a/torch/csrc/jit/operator.h +++ b/torch/csrc/jit/operator.h @@ -38,6 +38,7 @@ struct TORCH_API Operator { // as attributes or inputs. This function returns the right Operation function, // given a node encoded for one variant. // Behavior is undefined if matches(n) == false + // TODO (apaszke) : remove Operation selectVariant(Node* n) const { if(n->hasAttributes()) { JIT_ASSERT(op_const_attributes != nullptr); @@ -77,4 +78,13 @@ struct TORCH_API RegisterOperators { } }; +struct OperatorSet { + OperatorSet(std::initializer_list sig_literals); + // XXX: Returns a nullptr if no Operator in the set matches n + Operator* find(Node *n); +private: + std::unordered_map>> ops; +}; + + }} diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.h b/torch/csrc/jit/passes/common_subexpression_elimination.h index 64ae4f6bd9ca8b..f74f0868eb7a88 100644 --- a/torch/csrc/jit/passes/common_subexpression_elimination.h +++ b/torch/csrc/jit/passes/common_subexpression_elimination.h @@ -5,5 +5,6 @@ namespace torch { namespace jit { TORCH_API void EliminateCommonSubexpression(std::shared_ptr& graph); +TORCH_API void EliminateCommonSubexpression(Block * block); }} diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index 28b1195efd5273..d37ff6dfea5b43 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -35,10 +35,7 @@ void mergeNodes(Block * block, Symbol group_node_kind, ArrayRef nodes) { value_map[v] = nv; return nv; }; - std::unordered_set group_set; - for(auto n : nodes) { - group_set.insert(n); - } + std::unordered_set group_set(nodes.begin(), nodes.end()); for(auto n : nodes) { auto nn = new_graph->appendNode(new_graph->createClone(n, getOrCreateInput)); for(size_t i = 0; i < nn->outputs().size(); ++i) { diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp index 91f08c0941e7c2..0892b3e6cbdfe3 100644 --- a/torch/csrc/jit/passes/erase_number_types.cpp +++ 
b/torch/csrc/jit/passes/erase_number_types.cpp @@ -13,7 +13,7 @@ static void EraseNumberTypesOnBlock(Block* block) { case prim::Constant: { // remove primitive constants, replacing with tensor equivalent // ONNX does not support non-tensor constants - if(it->output()->type()->isSubtypeOf(*NumberType::get())) { + if(it->output()->type()->isSubtypeOf(NumberType::get())) { auto s = *constant_as(it->output()); WithInsertPoint guard(*it); Value* r = insertConstant(*block->owningGraph(), s.toTensor()); @@ -27,7 +27,7 @@ static void EraseNumberTypesOnBlock(Block* block) { } break; default: { for(auto o : it->outputs()) { - if (o->type()->isSubtypeOf(*NumberType::get())) { + if (o->type()->isSubtypeOf(NumberType::get())) { o->setType(TensorType::fromNumberType(o->type())); } } diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index 660a4ac4e8ad38..e0395ffdaadeae 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -32,7 +32,6 @@ std::unordered_set simple_mappable = { aten::atan, aten::atan2, aten::ceil, - aten::clamp, aten::cos, aten::cosh, aten::div, @@ -45,7 +44,6 @@ std::unordered_set simple_mappable = { aten::ge, aten::gt, aten::le, - aten::lerp, aten::lgamma, aten::log, aten::log10, @@ -74,33 +72,17 @@ std::unordered_set simple_mappable = { aten::type_as, aten::_sigmoid_backward, aten::_tanh_backward, + // TODO support those + //aten::clamp, + //aten::lerp, }; bool isSimpleMap(Node *node) { + // TODO: use signature matching if(simple_mappable.count(node->kind()) == 0) return false; if((node->kind() == aten::min || node->kind() == aten::max) && node->inputs().size() == 1) return false; - // Make sure that the node doesn't broadcast. - JIT_ASSERT(node->inputs().size() > 0); - TensorType* expected_type = node->inputs()[0]->type()->cast(); - if (!expected_type) return false; -//type checking is intentionally dropped from isSimpleMap -//isFusable is checking input/output types as there are some exceptions from allFloatIO requirement - static const auto equal_modulo_strides = [](TensorType* expected, const TypePtr& _actual) { - TensorType* actual = _actual->cast(); - return actual && - expected->device() == actual->device() && - expected->sizes() == actual->sizes(); - }; - for (Value * val : node->inputs()) { - if (!equal_modulo_strides(expected_type, val->type())) - return false; - } - for (Value * val : node->outputs()) { - if (!equal_modulo_strides(expected_type, val->type())) - return false; - } return true; } @@ -133,10 +115,8 @@ struct GraphFuser { bool hasSupportedType(Value* node) { if (auto tt = node->type()->cast()) { if (tt->scalarType() == at::kFloat) return true; - #ifdef USE_CUDA // Checks for half tensor on GPU - // const auto device = tt->device(); if (tt->device() != kCPUDevice && CUDA_VERSION >= 9 && tt->scalarType() == at::ScalarType::Half) { @@ -144,57 +124,74 @@ struct GraphFuser { } #endif } - return false; } - bool allSupportedList(at::ArrayRef list){ - for (auto& o: list){ - if (!hasSupportedType(o)) return false; + bool areTensorsOfSameShape(at::ArrayRef values) { + auto expected_type = values.at(0)->type()->cast(); + if (!expected_type) return false; + for (Value * val : values) { + auto val_type = val->type()->cast(); + if (!val_type) return false; + if (expected_type->device() != val_type->device()) return false; + if (expected_type->sizes() != val_type->sizes()) return false; } - return true; } - bool allSupportedIO(Node* node) { - return (allSupportedList(node->inputs()) && 
allSupportedList(node->outputs())); + bool hasSupportedType(Node* node) { + return areTensorsOfSameShape(node->inputs()) && + haveSupportedType(node->inputs()) && + haveSupportedType(node->outputs()); + } + + bool haveSupportedType(at::ArrayRef list) { + for (Value *v : list) { + if (!hasSupportedType(v)) return false; + } + return true; } + bool isFusable(Node * node) { if (node->owningBlock() != block) return false; if (node->kind() == prim::FusionGroup) return true; if (!isSimpleMap(node)) return false; - switch (node->kind()){ -//comparison operators produce Byte type, and it's ok, check only inputs - case aten::le: - case aten::ge: - case aten::lt: - case aten::gt: - case aten::ne: - case aten::eq: - return allSupportedList(node->inputs()); - case aten::type_as: -//type_as can have different input types as long as output is float, check only output - return allSupportedList(node->outputs()); - default: - return allSupportedIO(node); + + if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", /*const=*/attr::alpha)) { + std::vector inputs {node->namedInput(attr::self), node->namedInput(attr::other)}; + return areTensorsOfSameShape(inputs) && haveSupportedType(inputs); + } else if (node->matches("aten::add(Tensor self, Scalar other, *, Scalar alpha) -> Tensor", /*const=*/{attr::other, attr::alpha})) { + return hasSupportedType(node->namedInput(attr::self)); + } else if (node->matches("aten::lt(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::le(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::eq(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::ne(Tensor self, Tensor other) -> Tensor")) { + // comparison operators produce Byte type, and it's ok, check only inputs + return areTensorsOfSameShape(node->inputs()) && haveSupportedType(node->inputs()); + } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) { + // type_as can have different input types as long as output is float, check only output + return haveSupportedType(node->outputs()); + } else { + return hasSupportedType(node); } } - bool allOutputsHaveSameSize(Node * node) { - TensorType *tt_ptr = nullptr; - for (const auto i : node->inputs()) { - auto cur_tt_ptr = i->type()->cast(); - if (!cur_tt_ptr) { - return false; - } - - if (tt_ptr && tt_ptr->sizes() != cur_tt_ptr->sizes()) { - return false; - } - tt_ptr = cur_tt_ptr; + bool allCatInputsHaveSameSize(Node * node) { + JIT_ASSERT(node->kind() == aten::cat); + std::vector inputs = node->inputs(); + if (!node->hasAttributes()) { + inputs.pop_back(); // Get rid of the dim argument } - return true; + + auto expected = inputs.at(0)->type()->cast(); + if (!expected) return false; + return std::all_of(inputs.begin(), inputs.end(), [expected](Value *v) { + auto actual = v->type()->cast(); + return actual && actual->sizes() == expected->sizes(); + }); } // Can this node produce an _output_ of a fusion group? 
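The isFusable rewrite above keys everything off schema matching rather than attribute inspection. For reference, here is a stand-alone sketch of that matches()/namedInput() idiom under the same assumptions as the code above; the predicate name isFusableAddSketch is made up for illustration.

#include "torch/csrc/jit/ir.h"

using namespace torch::jit;

// Hypothetical predicate showing the idiom: the schema literal pins down the
// overload, the symbols passed as /*const=*/ must be compile-time constants for
// the match to succeed, and namedInput() then looks arguments up by name.
static bool isFusableAddSketch(Node* n) {
  if (!n->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
                  /*const=*/attr::alpha)) {
    return false;
  }
  Value* self = n->namedInput(attr::self);
  Value* other = n->namedInput(attr::other);
  return self->type()->cast<TensorType>() != nullptr &&
         other->type()->cast<TensorType>() != nullptr;
}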
@@ -207,7 +204,7 @@ struct GraphFuser { // this concat fusion only works when all the inputs are the same size // and we can statically infer the dimension along which we should concat // otherwise they cannot partipate in the same map - if(node->kind() == aten::cat && node->get(attr::dim) && allOutputsHaveSameSize(node)) + if(node->kind() == aten::cat && node->is_constant(attr::dim) && allCatInputsHaveSameSize(node)) return true; return false; @@ -327,12 +324,24 @@ struct GraphFuser { inputs_map[input] = subgraph.inputs()[i++]; } // add n's inputs to the fusion group's input list if we don't already have them + Node * insert_after = nullptr; for (auto input : n->inputs()) { if (inputs_map.count(input) == 0) { - auto in_group = subgraph.addInput(); - in_group->setType(input->type()); - inputs_map[input] = in_group; - group->addInput(input); + if (input->type()->isSubtypeOf(DynamicType::get())) { + auto in_group = subgraph.addInput(); + in_group->setType(input->type()); + inputs_map[input] = in_group; + group->addInput(input); + } else { + // We don't support passing in scalars as arguments to fused kernels, so we generally + // don't allow fusing tensor-scalar operations unless the scalar is constant. In those + // cases we inline the constants directly in the body of the fused group. + JIT_ASSERT(input->node()->kind() == prim::Constant); + Node * in_const = subgraph.createClone(input->node(), [](Value*) -> Value* { throw std::runtime_error("unexpected input"); }); + subgraph.prependNode(in_const); + insert_after = in_const; + inputs_map[input] = in_const->output(); + } } } // copy n into the graph, remapping its inputs to internal nodes @@ -351,7 +360,7 @@ struct GraphFuser { subgraph.inputs()[p]->replaceAllUsesWith(in_graph->output()); subgraph.eraseInput(p); } - return subgraph.prependNode(in_graph); + return insert_after ? in_graph->insertAfter(insert_after) : subgraph.prependNode(in_graph); } // turn consumer node n into a fusion group with just n inside @@ -432,7 +441,7 @@ struct GraphFuser { if (!isChunk(chunk)) return false; // and the thing being chunked is fusable into the consumer - Value * producer_for_chunk = chunk->input(); + Value * producer_for_chunk = chunk->namedInput(attr::self); if (!isFusable(producer_for_chunk->node()) || !allUsersAreThisConsumer(chunk,producer_for_chunk)) return false; // and all uses of the chunk are in this consumer @@ -457,20 +466,25 @@ struct GraphFuser { std::vector> chunked_inputs; for (auto input : producer_for_chunk_node->inputs()) { auto input_type = input->type()->cast(); + // XXX: we only work with pointwise ops in here, so we know it is valid to push + // the concat only through tensor arguments (and all other args can be safely ignored). + if (!input_type) + continue; // NB: I decided not to use cloneFrom here, because if we make cloneFrom // copy selects one day, it is definitely not what you want here (selects // have different types). // TODO: Perhaps we should use cloneFrom now, as it seems unlikely // to copy select nodes now that we have refactored to have a Value // distinct from Node. 
- Node * input_chunk = block->owningGraph()->create(chunk->kind(), 0); - input_chunk->copyAttributes(*chunk); + Node * input_chunk = block->owningGraph()->create(aten::chunk, 0); input_chunk->addInput(input); + input_chunk->addInput(chunk->namedInput(attr::chunks)); + input_chunk->addInput(chunk->namedInput(attr::dim)); insertAt(&insertion_point, input_chunk); chunked_inputs.emplace_back(); // alas, to not be C++17 for (auto chunk_sel : chunk->outputs()) { - auto chunk_sel_type = chunk_sel->type()->cast(); + auto chunk_sel_type = chunk_sel->type()->expect(); Value * input_chunk_sel = input_chunk->addOutput(); input_chunk_sel->setType( input_type->withSizesStrides(chunk_sel_type->sizes(), @@ -482,12 +496,20 @@ struct GraphFuser { // apply the op to each chunk of the chunked operands, // and then rewrite the graph to use them! for (auto chunk_sel : chunk->outputs()) { + auto original_inputs = producer_for_chunk_node->inputs(); Node * chunked_op = block->owningGraph()->create(producer_for_chunk_node->kind()); chunked_op->copyAttributes(*producer_for_chunk_node); // Invariant: mappable operators always produce contiguous output chunked_op->output()->setType(chunk_sel->type()->cast()->contiguous()); - for (auto by_chunk_output_idx : chunked_inputs) { - chunked_op->addInput(by_chunk_output_idx.at(chunk_sel->offset())); + auto chunked_inputs_it = chunked_inputs.begin(); + for (size_t i = 0; i < original_inputs.size(); ++i) { + if (original_inputs[i]->type()->isSubtypeOf(DynamicType::get())) { + JIT_ASSERT(chunked_inputs_it != chunked_inputs.end()); + chunked_op->addInput(chunked_inputs_it->at(chunk_sel->offset())); + ++chunked_inputs_it; + } else { + chunked_op->addInput(original_inputs[i]); + } } insertAt(&insertion_point, chunked_op); chunk_sel->replaceAllUsesWith(chunked_op->output()); diff --git a/torch/csrc/jit/passes/lower_tuples.cpp b/torch/csrc/jit/passes/lower_tuples.cpp index 34f8c56f5607fe..89c74d4cf1fa7c 100644 --- a/torch/csrc/jit/passes/lower_tuples.cpp +++ b/torch/csrc/jit/passes/lower_tuples.cpp @@ -43,7 +43,7 @@ static void VisitNode(Node* n, Node* insert_point) { // flatten the input list op(a, tup, b) --> op(a, t0, t1, b) for(size_t i = 0; i < n->inputs().size();) { auto input = n->inputs()[i]; - if(TupleType* tt = input->type()->cast()) { + if(TupleTypePtr tt = input->type()->cast()) { JIT_ASSERTM(white_list.count(n->kind()) > 0, "tuple appears in op that does not forward tuples"); JIT_ASSERTM(input->node()->kind() == prim::TupleConstruct, "tuple use not matched to tuple construct"); for(size_t j = 0; j < tt->elements().size(); ++j) { @@ -68,7 +68,7 @@ static void VisitNode(Node* n, Node* insert_point) { // and: // tup = (t0, t1) // is placed at the current insertion point - if(TupleType* tt = output->type()->cast()) { + if(TupleTypePtr tt = output->type()->cast()) { JIT_ASSERTM(white_list.count(n->kind()) > 0, "tuple appears in op that does not forward tuples"); for(size_t j = 0; j < tt->elements().size(); j++) { n->insertOutput(i + 1 + j)->setType(tt->elements()[j]); diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index 030b4dc7a34395..75fb063c761a31 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -1,5 +1,6 @@ #include "torch/csrc/utils/pybind.h" #include "torch/csrc/jit/passes/onnx.h" +#include "torch/csrc/jit/passes/dead_code_elimination.h" #include "torch/csrc/autograd/function.h" #include "torch/csrc/autograd/symbolic.h" #include "torch/csrc/jit/assertions.h" @@ -194,6 +195,7 @@ void BlockToONNX(Block* 
old_block, Block* new_block, ::torch::onnx::OperatorExpo // Copy stage from original graph ctx.block->owningGraph()->setStage(old_block->owningGraph()->stage()); + EliminateDeadCode(ctx.block); } }} diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 4620dd3812d56b..7fcd47f3a23b54 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -96,13 +96,18 @@ void fuseBroadcast(Block *b) { JIT_ASSERT(!n->hasAttribute(attr::axis)); auto input_index = n->inputs().size() - 1; - auto* expanded_rhs = n->input(input_index)->node(); - - // The expanded_rhs input isn't actually an expand, so no fusion available - if (expanded_rhs->kind() != aten::expand) continue; - if (expanded_rhs->inputs().size() != 1) continue; + auto* rhs_expand = n->input(input_index)->node(); + + // The rhs_expand input isn't actually an expand, so no fusion available + // XXX: we can't use the ->matches(...) mechanism in here, because input nodes + // have been converted to onnx::Constant by this point + if (rhs_expand->kind() != aten::expand || + rhs_expand->input(1)->node()->kind() != onnx::Constant || + rhs_expand->input(2)->node()->kind() != onnx::Constant) { + continue; + } - auto* unexpanded_rhs = expanded_rhs->input(); + auto* unexpanded_rhs = rhs_expand->input(0); // We need to know what the type pre-expand is. We should basically // always have this information (because expands are only ever traced, @@ -113,7 +118,7 @@ // Not all broadcasts are supported by ONNX broadcast. at::optional axis = fusibleExpandTo( unexpanded_rhs->type()->expect()->sizes(), // from - expanded_rhs->output()->type()->expect()->sizes()); // to + rhs_expand->output()->type()->expect()->sizes()); // to if (axis == at::nullopt) continue; @@ -128,8 +133,8 @@ n->i_(attr::axis, axis.value()); } } - if (!expanded_rhs->hasUses()) { - expanded_rhs->destroy(); + if (!rhs_expand->hasUses()) { + rhs_expand->destroy(); } } } @@ -265,13 +270,13 @@ void pushPackingPastRnn(Block *b) { // unhygenic way, Pytorch ends up propagating an incorrect type. // Until a long-term cleanup comes around, we can fix this by // resetting the size to the correct value.
- TensorType* oldType = rnn->inputs()[0]->type()->cast(); + TensorTypePtr oldType = rnn->inputs()[0]->type()->cast(); if (oldType) { std::vector new_sizes; new_sizes.push_back(oldType->sizes()[0]); new_sizes.push_back(oldType->sizes()[1]); new_sizes.push_back(rnn->i(attr::hidden_size)); - TensorTypePtr newType = std::make_shared( + TensorTypePtr newType = TensorType::create( oldType->scalarType(), oldType->device(), new_sizes); next->outputs()[0]->setType(newType); } diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index 2ee777aee0e66f..ac0c96232647e9 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -30,7 +30,7 @@ void PeepholeOptimize(Block * block) { if (auto input_type = node->namedInput(attr::self)->type()->cast()) { auto expanded_sizes = node->get>(attr::size); if (expanded_sizes == input_type->sizes()) { - node->output()->replaceAllUsesWith(node->input()); + node->output()->replaceAllUsesWith(node->namedInput(attr::self)); } } } else if (node->matches("aten::t(Tensor self) -> Tensor")) { diff --git a/torch/csrc/jit/passes/remove_expands.cpp b/torch/csrc/jit/passes/remove_expands.cpp index f0f591cac59ec9..93d53e54819bbd 100644 --- a/torch/csrc/jit/passes/remove_expands.cpp +++ b/torch/csrc/jit/passes/remove_expands.cpp @@ -9,7 +9,7 @@ static void RemoveExpands(Block* block) { for (auto sub : it->blocks()) RemoveExpands(sub); if (it->kind() == aten::expand && it->get(attr::implicit) != static_cast(0)) { - it->output()->replaceAllUsesWith(it->input()); + it->output()->replaceAllUsesWith(it->namedInput(attr::self)); it.destroyCurrent(); } } diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 3b18699f94ffcd..f1fef4c5247ea0 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -42,12 +42,12 @@ IValue representativeValue(Value* v) { if(auto iv = toIValue(v)) { return *iv; } - if (TensorType* type = type_->cast()) { + if (TensorTypePtr type = type_->cast()) { auto backend = type->device() == -1 ? 
at::kCPU : at::kCUDA; at::DeviceGuard device_guard(type->device()); auto& attype = at::getType(backend, type->scalarType()); return attype.tensor(type->sizes(), type->strides()).zero_(); - } else if (type_->isSubtypeOf(*FloatType::get())) { + } else if (type_->isSubtypeOf(FloatType::get())) { return 0.f; } // we should not get here because isValidArgumentForRunning should have @@ -63,8 +63,8 @@ void PropagateShapeOnBlock(Block * block, bool insert_expands=true); // for each node in the schema with type Tensor, extract the TensorType // returns at::nullopt if any Tensor in the schema does not have a known shape // ignores non-tensor in the list of inputs -at::optional> gatherTensorTypes(Node *node) { - std::vector tensor_types; +at::optional> gatherTensorTypes(Node *node) { + std::vector tensor_types; auto & schema = node->schema(); auto & args = schema.arguments; @@ -75,12 +75,12 @@ at::optional> gatherTensorTypes(Node *node) { size_t input_i = 0; for (auto& arg : args) { size_t consume_n; // how many tensors do we check for in the input list - if (arg.type->isSubtypeOf(*ListType::ofTensors())) { + if (arg.type->isSubtypeOf(ListType::ofTensors())) { // we have a list of tensor, there is only ever one list // so we calculte how many elements must be in it by how much bigger // or smaller the input list is compared to the arguments in the schema consume_n = node->inputs().size() + 1 - args.size(); - } else if (arg.type->isSubtypeOf(*DynamicType::get())) { + } else if (arg.type->isSubtypeOf(DynamicType::get())) { // a single Tensor for this argument consume_n = 1; } else { @@ -89,7 +89,7 @@ at::optional> gatherTensorTypes(Node *node) { } for(size_t j = 0; j < consume_n; j++) { // bail out if a tensor does not have a size - TensorType *type = node->input(input_i++)->type()->cast(); + TensorTypePtr type = node->input(input_i++)->type()->cast(); if (!type) return at::nullopt; tensor_types.push_back(type); @@ -117,16 +117,18 @@ bool mergeTypes(ArrayRef lhs, ArrayRef rhs, ArrayRef out void PropagateShapeOnNode(Node * node, bool insert_expands=true); -void broadcastBinary(Node *node, std::vector& types, size_t idx1, size_t idx2) { +void broadcastBinary(Node *node, std::vector& types, size_t idx1, size_t idx2) { auto expected_size = at::infer_size(types[idx1]->sizes(), types[idx2]->sizes()); auto broadcast = [&](size_t input_idx) { - TensorType* input_type = types.at(input_idx); + TensorTypePtr input_type = types.at(input_idx); if (input_type->sizes() == expected_size) return; auto graph = node->owningGraph(); - Node *expand = graph->create(aten::expand, {node->inputs().at(input_idx)}) - ->is_(attr::size, expected_size) - ->i_(attr::implicit, 0) + WithInsertPoint point_guard { node }; + Node *expand = graph->create(aten::expand, + {node->inputs().at(input_idx), + insertConstant(*graph, expected_size), + insertConstant(*graph, 0)}) ->insertBefore(node); PropagateShapeOnNode(expand); node->replaceInput(input_idx, expand->output()); @@ -178,14 +180,14 @@ bool isValidArgumentForRunning(Value* v) { // allow constants if(toIValue(v)) return true; - if(TensorType* tt = v->type()->cast()) { + if(TensorTypePtr tt = v->type()->cast()) { return !at::isIntegralType(tt->scalarType()); } - return v->type()->isSubtypeOf(*FloatType::get()); + return v->type()->isSubtypeOf(FloatType::get()); } bool isValidReturnForRunning(Value* v) { - return v->type()->isSubtypeOf(*DynamicType::get()) || - v->type()->isSubtypeOf(*NumberType::get()); + return v->type()->isSubtypeOf(DynamicType::get()) || + 
v->type()->isSubtypeOf(NumberType::get()); } bool canPropagateShapeByRunningIt(Node* node) { @@ -244,7 +246,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { case prim::NumToTensor: return; // correct num type is already set case prim::Constant: { - if(node->output()->type()->isSubtypeOf(*DynamicType::get())) { + if(node->output()->type()->isSubtypeOf(DynamicType::get())) { node->output()->inferTypeFrom(node->t(attr::value)); } return; @@ -280,6 +282,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { node->matches("aten::min(Tensor self, Tensor other) -> Tensor") || node->matches("aten::max(Tensor self, Tensor other) -> Tensor") || node->matches("aten::lt(Tensor self, Tensor other) -> Tensor") || + node->matches("aten::le(Tensor self, Tensor other) -> Tensor") || node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") || node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") || node->matches("aten::eq(Tensor self, Tensor other) -> Tensor") || @@ -296,7 +299,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { auto lhs_type = tensor_types.at(0); auto rhs_type = tensor_types.at(1); SHAPE_ASSERT(lhs_type->sizes().size() == 2 && rhs_type->sizes().size() == 2); - node->output()->setType(std::make_shared( + node->output()->setType(TensorType::create( lhs_type->scalarType(), lhs_type->device(), at::IntList{lhs_type->sizes().at(0), rhs_type->sizes().at(1)})); return; @@ -419,7 +422,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { std::vector dim_vec = {(int64_t)tensor_types.at(0)->sizes().size()}; at::IntList dims(dim_vec); node->output()->setType( - std::make_shared(at::kLong, -1, dims)); + TensorType::create(at::kLong, -1, dims)); return; } else if (node->kind() == onnx::Reshape) { setUnshapedType(node); @@ -451,7 +454,8 @@ void PropagateShapeOnBlock(Block * block, bool insert_expands) { } } -} +} // anonymous namespace + void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec) { JIT_ASSERT(graph.inputs().size() == spec.size()); for(size_t i = 0; i < spec.size(); ++i) { @@ -462,4 +466,29 @@ void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec) { PropagateShapeOnBlock(graph.block()); } +namespace { + +void EraseShapeInformation(at::ArrayRef vals) { + for (Value * v : vals) { + v->setType(unshapedType(v->type())); + } +} + +void EraseShapeInformation(Block * b) { + EraseShapeInformation(b->inputs()); + EraseShapeInformation(b->outputs()); + for (Node * n : b->nodes()) { + EraseShapeInformation(n->outputs()); + for (Block *sb : n->blocks()) { + EraseShapeInformation(sb); + } + } +} + +} // anonymous namespace + +void EraseShapeInformation(Graph & graph) { + EraseShapeInformation(graph.block()); +} + }} diff --git a/torch/csrc/jit/passes/shape_analysis.h b/torch/csrc/jit/passes/shape_analysis.h index 1b38cbbe5739a4..199d376e87ddec 100644 --- a/torch/csrc/jit/passes/shape_analysis.h +++ b/torch/csrc/jit/passes/shape_analysis.h @@ -3,8 +3,11 @@ #include "torch/csrc/WindowsTorchApiMacro.h" namespace torch { namespace jit { + struct Graph; struct ArgumentSpec; + +void EraseShapeInformation(Graph & graph); TORCH_API void PropagateInputShapes(Graph & graph, const ArgumentSpec & spec); }} diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index c9e41e8a7eee26..81211085569953 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -281,7 +281,10 @@ void initPythonIRBindings(PyObject * module_) { #undef VS - py::class_>(m, "Block"); + py::class_>(m, "Block") + 
.def("nodes",[](Block &b) { + return py::make_iterator(b.nodes().begin(), b.nodes().end()); + }); #define NS(name) \ def(#name,&Node :: name) @@ -451,9 +454,9 @@ void initPythonIRBindings(PyObject * module_) { ; py::class_>(m, "DynamicType") - .def(py::init<>()); + .def(py::init([](){ return DynamicType::create(); })); py::class_>(m, "TupleType") - .def(py::init>()) + .def(py::init([](std::vector a){ return TupleType::create(a); })) .def("elements", [](TupleType &self){ std::vector types; for (auto type : self.elements()) { @@ -461,6 +464,8 @@ void initPythonIRBindings(PyObject * module_) { } return types; }); + py::class_>(m, "ListType") + .def_static("ofInts", &ListType::ofInts); py::class_(m,"Use") .def_readonly("user",&Use::user) diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp index 78247017b9e9c6..7439b2b5e334cc 100644 --- a/torch/csrc/jit/python_tracer.cpp +++ b/torch/csrc/jit/python_tracer.cpp @@ -77,12 +77,25 @@ PreTraceInfo preRecordPythonTrace(THPObjectPtr pyobj, if(!apply) { throw python_error(); } - return makePreTraceInfo(inputs, [&](const std::shared_ptr& state, Graph& graph) { - return graph.createPythonOp( - std::move(apply), - arg_types, - std::move(scalar_args)); - }); + + PreTraceInfo info; + auto & state = getTracingState(); + auto & graph = state->graph; + + Node *n = info.n = graph->createPythonOp( + std::move(apply), + arg_types, + std::move(scalar_args)); + recordSourceLocation(n); + + for (const Variable & input : inputs) { + n->addInput(getValueTrace(input)); + } + + // NB: Order matters. This must append after inputs but before outputs. + graph->appendNode(n); + + return info; } void pythonRecordSourceLocation(Node* n) { diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 010e0919f9cd03..b1fa7dc4c4185f 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -202,7 +202,7 @@ RegisterOperators reg({ prim::ListConstruct, [](Node* node) -> Operation { size_t num_inputs = node->inputs().size(); - ListType* lt = node->output()->type()->expect(); + ListTypePtr lt = node->output()->type()->expect(); if(IntType::get() == lt->getElementType()) { return [=](Stack& stack) { auto inputs = peekSlice(stack, 0, num_inputs, num_inputs); diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index be819775f1dc94..6cf7b37d4f43c0 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -93,9 +93,9 @@ struct CastValue : public SugaredValue { throw ErrorReport(loc) << "expected a single argument for cast"; auto values = toValues(inputs); Value* input = values.at(0); - if(!input->type()->isSubtypeOf(*type)) { + if(!input->type()->isSubtypeOf(type)) { if(*type == *DynamicType::get()) { - if(!input->type()->isSubtypeOf(*NumberType::get())) { + if(!input->type()->isSubtypeOf(NumberType::get())) { throw ErrorReport(loc) << "expected a number"; } input = numToTensor(loc, input); @@ -149,8 +149,9 @@ struct Environment { std::shared_ptr next; SugaredValuePtr findInThisFrame(const std::string& name) { - if (value_table.count(name)) { - return value_table.at(name); + auto it = value_table.find(name); + if (it != value_table.end()) { + return it->second; } return nullptr; } @@ -244,7 +245,7 @@ struct Environment { throw ErrorReport(loc) << "Cannot re-assign '" << name << "' because it has type " << value->kind() << " and " << name << " is not a first-class value. 
Only reassignments to first-class values are allowed"; } - if(!as_simple_value->type()->isSubtypeOf(*unshapedType(simple_parent->type()))) { + if(!as_simple_value->type()->isSubtypeOf(unshapedType(simple_parent->type()))) { throw ErrorReport(loc) << "variable '" << name << "' previously has type " << simple_parent->type()->str() << " but is now being assigned to a value of type " << as_simple_value->type()->str(); } @@ -368,7 +369,7 @@ Value* createStack(Graph& g, const SourceRange& loc, at::ArrayRef inputs } static bool isTensorSubtype(Value* v) { - return v->type()->isSubtypeOf(*DynamicType::get()); + return v->type()->isSubtypeOf(DynamicType::get()); } at::optional> getIntListAttribute(at::optional N, Value* input) { @@ -385,72 +386,6 @@ at::optional> getIntListAttribute(at::optional N, return std::vector(*N, *r); } -// try to turn constant inputs into attributes -void liftConstantAttributes(const FunctionSchema& schema, Node* node) { - // we shouldn't start with attributes, just inputs - JIT_ASSERT(!node->hasAttributes()); - std::vector new_inputs; - Attributes attributes; - for(size_t i = 0, n = 0; i < schema.arguments.size(); ++i) { - const auto& arg = schema.arguments[i]; - // this was a builtin with a vararg list lowered, - if(*arg.type == *ListType::ofTensors()) { - // we need to skip all the vararg nodes, and continue parsing the - // possible attribute nodes - size_t vararg_list_size = node->inputs().size() - (schema.arguments.size() - 1); - while(n < i + vararg_list_size) { - new_inputs.push_back(node->input(n++)); - } - continue; - } - auto input = node->input(n++); - switch(arg.type->kind()) { - case TypeKind::IntType:{ - auto r = constant_as(input); - if(!r) - return; - attributes.i_(Symbol::attr(arg.name), *r); - } break; - case TypeKind::FloatType: { - auto r = constant_as(input); - if(!r) - return; - attributes.f_(Symbol::attr(arg.name), *r); - } break; - case TypeKind::NumberType: { - auto r = constant_as(input); - if(!r) - return; - attributes.t_(Symbol::attr(arg.name), r->toTensor()); - } break; - case TypeKind::ListType: { - auto elem = arg.type->expect()->getElementType(); - if(elem->kind() == TypeKind::IntType) { - auto r = getIntListAttribute(arg.N, input); - if(!r) - return; - attributes.is_(Symbol::attr(arg.name), *r); - } else { - // only IntLists can become attributes, other - // types are not attribute-able - new_inputs.push_back(input); - } - } break; - default: - new_inputs.push_back(input); - } - } - // nothing changed no need to modify the node - if(!attributes.hasAttributes()) - return; - - node->removeAllInputs(); - for(Value* input : new_inputs) { - node->addInput(input); - } - node->copyAttributes(attributes); -} - at::ArrayRef createTupleUnpack(Value* v) { // small peephole optimization to ensure IntList attributes can still turn // into constants e.g. 
in x.expand([3, 4]) @@ -515,10 +450,8 @@ at::optional> tryMatchSchema( return at::nullopt; } positional_inputs[i] = NamedValue( - loc, - i, - insertConstant(graph, *default_value, loc) - ->setType(schema.arguments[i].type)); + loc, i, + insertConstant(graph, *default_value, loc)); } // check input types @@ -538,12 +471,12 @@ at::optional> tryMatchSchema( // Allow tuples that only contain integers to turn into lists of integers if(*ListType::ofInts() == *arg.type && v.value->type()->kind() == TypeKind::TupleType && - v.value->type()->isSubtypeOf(*ListType::ofInts())) { + v.value->type()->isSubtypeOf(ListType::ofInts())) { auto unpacked = createTupleUnpack(v.value); v.value = graph.insertNode(graph.createList(IntType::get(), unpacked))->output(); } - if(!v.value->type()->isSubtypeOf(*arg.type)) { + if(!v.value->type()->isSubtypeOf(arg.type)) { err() << "expected a value of type " << arg.type->str() << " for argument '" << arg.name << "' but found " << v.value->type()->str() << "\n" << v.loc; @@ -551,7 +484,7 @@ at::optional> tryMatchSchema( } // we only support tensor lists for builtins, where they must be flattened - if(arg.type->isSubtypeOf(*ListType::ofTensors())) { + if(arg.type->isSubtypeOf(ListType::ofTensors())) { auto outputs = createTupleUnpack(v.value); flat_inputs.insert(flat_inputs.end(), outputs.begin(), outputs.end()); } else { @@ -578,10 +511,6 @@ static std::shared_ptr tryEmitBuiltin( return nullptr; // we successfully matched this schema, construct the node - // note: we always construct purely positional nodes here - // the pass liftConstantAttributes replaces the node with with one that - // uses attributes if all the attributes ended up as constants - NodeKind kind(Symbol::aten(name)); auto n = graph->insertNode(graph->create(kind, *flat_inputs, 0)) ->setSourceLocation(std::make_shared(loc)); @@ -602,9 +531,6 @@ static std::shared_ptr tryEmitBuiltin( } } - if(op->hasAttributedVersion()) - liftConstantAttributes(op->schema, n); - // assert that we did indeed create an op that has implementation // otherwise schema and dispatch are not in sync getOperation(n); @@ -663,7 +589,7 @@ static Value* ensureTensor(const SourceRange& range, Value* v) { } static Value* ensureInt(const SourceRange& range, Value* v) { - if(!v->type()->isSubtypeOf(*IntType::get())) { + if(!v->type()->isSubtypeOf(IntType::get())) { throw ErrorReport(range) << "expected a int but found a " << v->type()->str(); } @@ -778,7 +704,7 @@ struct to_ir { auto range = return_stmt.range(); size_t return_type_idx = 0; for (auto& r : results) { - if(r->type()->isSubtypeOf(*NumberType::get())) { + if(r->type()->isSubtypeOf(NumberType::get())) { graph->registerOutput(numToTensor(range, r)); } else { ensureTensor(range, r); @@ -787,7 +713,7 @@ struct to_ir { TypePtr type = DynamicType::get(); if (typed_def.schema) { type = typed_def.schema->returns.at(return_type_idx).type; - if (!r->type()->isSubtypeOf(*type)) { + if (!r->type()->isSubtypeOf(type)) { throw ErrorReport(return_stmt.range()) << "Return value at position " << return_type_idx << " was annotated as having type " << type->str() << " but is actually of type " << r->type()->str(); @@ -914,10 +840,10 @@ struct to_ir { Value* emitCond(Expr cond) { Value* v = emitExpr(cond, identity); - if(v->type()->isSubtypeOf(*DynamicType::get())) { + if(v->type()->isSubtypeOf(DynamicType::get())) { v = tensorToNum(cond.range(), v, IntType::get()); } - if(!v->type()->isSubtypeOf(*IntType::get())) { + if(!v->type()->isSubtypeOf(IntType::get())) { throw ErrorReport(cond) << 
"expected a tensor or integer expression for condition but found " << v->type()->str(); } return v; diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 576344427c0461..cb7893234dc747 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -86,7 +86,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { << "of arguments: expected " << arguments.size() << ", but got " << inputs.size(); for (size_t i = 0; i < arguments.size(); ++i) { - if (!inputs[i]->type()->isSubtypeOf(*arguments[i])) + if (!inputs[i]->type()->isSubtypeOf(arguments[i])) throw ErrorReport(loc) << "type mismatch at argument " << i << ": expected " << arguments[i]->str() << ", but got " << inputs[i]->type()->str(); } @@ -135,7 +135,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { // equivalent, but the PythonOp impl ends with an optional tuple unpack, so we need // to do it. for (auto & ret_type_elem : returns) { - if (!ret_type_elem->isSubtypeOf(*DynamicType::get())) { + if (!ret_type_elem->isSubtypeOf(DynamicType::get())) { throw ErrorReport(loc) << "Python functions can currently only return Tensors"; } } diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 76518aaf1d26fa..1120d0bcaad740 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -119,14 +119,14 @@ struct Method { for (size_t i=0; i < retval->inputs().size(); ++i) { auto scalar_type = inputs[i].type().scalarType(); auto sizes = inputs[i].sizes(); - auto type = std::make_shared(scalar_type, -1, sizes); + auto type = torch::jit::TensorType::create(scalar_type, -1, sizes); retval->inputs()[i]->setType(type); } JIT_ASSERT(retval->outputs().size() == outputs.size()); for (size_t i=0; i < retval->outputs().size(); ++i) { auto scalar_type = outputs[i].type().scalarType(); auto sizes = outputs[i].sizes(); - auto type = std::make_shared(scalar_type, -1, sizes); + auto type = torch::jit::TensorType::create(scalar_type, -1, sizes); retval->outputs()[i]->setType(type); } return retval; diff --git a/torch/csrc/jit/symbolic_variable.h b/torch/csrc/jit/symbolic_variable.h index ff9e5149068a81..e4d2f98ba0ea0f 100644 --- a/torch/csrc/jit/symbolic_variable.h +++ b/torch/csrc/jit/symbolic_variable.h @@ -1,6 +1,7 @@ #pragma once #include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/constants.h" namespace torch { namespace jit { @@ -56,90 +57,51 @@ struct SymbolicVariable { return create(aten::mul, {*this, rhs})[0].typeLike(*this); } SymbolicVariable operator*(at::Scalar rhs) const { - if(isConstInt(rhs, 1)) + if (isConstInt(rhs, 1)) return *this; - Node * n; - auto r = create(aten::mul, {*this}, 1, &n)[0]; - n->t_(attr::other, rhs.toTensor()); - return r; + return (*this) * insertConstant(rhs); } SymbolicVariable operator>(at::Scalar rhs) const { - Node * n; - auto r = create(aten::gt, {*this}, 1, &n)[0].typeLikeWithScalarType(*this, at::kByte); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::gt, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte); } SymbolicVariable operator<(at::Scalar rhs) const { - Node * n; - auto r = create(aten::lt, {*this}, 1, &n)[0].typeLikeWithScalarType(*this, at::kByte); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::lt, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte); } SymbolicVariable operator>=(at::Scalar rhs) const { - Node * n; - auto r = create(aten::ge, {*this}, 1, &n)[0].typeLikeWithScalarType(*this, 
at::kByte); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::ge, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte); } SymbolicVariable operator<=(at::Scalar rhs) const { - Node * n; - auto r = create(aten::le, {*this}, 1, &n)[0].typeLikeWithScalarType(*this, at::kByte); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::le, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte); } SymbolicVariable operator==(at::Scalar rhs) const { - Node * n; - auto r = create(aten::eq, {*this}, 1, &n)[0].typeLikeWithScalarType(*this, at::kByte); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::eq, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte); } SymbolicVariable operator!=(at::Scalar rhs) const { - Node * n; - auto r = create(aten::ne, {*this}, 1, &n)[0].typeLikeWithScalarType(*this, at::kByte); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::ne, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte); } SymbolicVariable operator+(const SymbolicVariable rhs) const { - Node * n; - auto r = create(aten::add, {*this, rhs}, 1, &n)[0].typeLike(*this); - n->t_(attr::alpha, at::Scalar(1).toTensor()); - return r; + return create(aten::add, {*this, rhs, insertConstant(1)})[0].typeLike(*this); } SymbolicVariable operator+(at::Scalar rhs) const { - Node * n; - auto r = create(aten::add, {*this}, 1, &n)[0].typeLike(*this); - n->t_(attr::alpha, at::Scalar(1).toTensor()); - n->t_(attr::other, rhs.toTensor()); - return r; + return (*this) + insertConstant(rhs); } SymbolicVariable operator-() const { return create(aten::neg, {*this})[0].typeLike(*this); } SymbolicVariable operator-(const SymbolicVariable rhs) const { - Node *n; - auto r = create(aten::sub, {*this, rhs}, 1, &n)[0].typeLike(*this); - n->t_(attr::alpha, at::Scalar(1).toTensor()); - return r; + return create(aten::sub, {*this, rhs, insertConstant(1)})[0].typeLike(*this); } SymbolicVariable operator/(at::Scalar rhs) const { - Node *n; - auto r = create(aten::div, {*this}, 1, &n)[0].typeLike(*this); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::div, {*this, insertConstant(rhs)})[0].typeLike(*this); } SymbolicVariable operator%(at::Scalar rhs) const { - Node *n; - auto r = create(aten::remainder, {*this}, 1, &n)[0].typeLike(*this); - n->t_(attr::other, rhs.toTensor()); - return r; + return create(aten::remainder, {*this, insertConstant(rhs)})[0].typeLike(*this); } SymbolicVariable mm(const SymbolicVariable rhs) const { - auto r = create(t("mm"), {*this, rhs})[0]; - return r; + return create(t("mm"), {*this, rhs})[0]; } SymbolicVariable t() const { - auto r = create(t("t"), {*this})[0]; - return r; + return create(t("t"), {*this})[0]; } SymbolicVariable sigmoid() const { return create(aten::sigmoid, {*this})[0].typeLike(*this); @@ -147,88 +109,73 @@ struct SymbolicVariable { SymbolicVariable tanh() const { return create(aten::tanh, {*this})[0].typeLike(*this); } - std::vector chunk(int32_t chunks, uint32_t dim) const { - Node * n; - auto r = create(t("chunk"), { *this }, chunks, &n); - n->i_(a("chunks"), chunks) - ->i_(a("dim"), dim); - return r; + std::vector chunk(int64_t chunks, int dim) const { + return create(t("chunk"), { *this , insertConstant(chunks), insertConstant(dim) }, chunks); } SymbolicVariable type_as(const SymbolicVariable rhs) const { return create(aten::type_as, {*this, rhs})[0].typeLikeWithRhsScalarType(*this, rhs); } SymbolicVariable 
narrow(int dim, int64_t start, int64_t length) const { - Node * n; - auto r = create(t("narrow"), { *this }, 1, &n)[0]; - n->i_(a("dim"), dim) - ->i_(a("start"), start) - ->i_(a("length"), length); - return r; + return create(t("narrow"), { *this, insertConstant(dim), insertConstant(start), insertConstant(length) }, 1)[0]; } static SymbolicVariable cat(ArrayRef inputs, Value* dim) { - Node* n; std::vector all_inputs = inputs; all_inputs.push_back(dim); - auto r = create(aten::cat, all_inputs, 1, &n)[0]; - return r; + return create(aten::cat, all_inputs)[0]; } - static SymbolicVariable cat(ArrayRef inputs, int32_t dim) { - Node* n; - auto r = create(aten::cat, inputs, 1, &n)[0]; - n->i_(attr::dim, dim); - return r; + static SymbolicVariable cat(ArrayRef inputs, int dim) { + JIT_ASSERT(inputs.size() > 0); + return SymbolicVariable::cat(inputs, inputs[0].insertConstant(dim)); } - static SymbolicVariable stack(ArrayRef inputs, int32_t dim) { - Node* n; - auto r = create(aten::stack, inputs, 1, &n)[0]; - n->i_(attr::dim, dim); - return r; + static SymbolicVariable stack(ArrayRef inputs, Value* dim) { + std::vector all_inputs = inputs; + all_inputs.push_back(dim); + return create(aten::stack, all_inputs)[0]; + } + static SymbolicVariable stack(ArrayRef inputs, int dim) { + JIT_ASSERT(inputs.size() > 0); + return SymbolicVariable::stack(inputs, inputs[0].insertConstant(dim)); } SymbolicVariable sum() const { - auto r = create(t("sum"), {*this})[0]; - return r; + return create(t("sum"), {*this})[0]; } SymbolicVariable sum(int dim, bool keepdim) const { - Node * n; - auto r = create(t("sum"), {*this}, 1, &n)[0]; - n->is_(a("dim"), {dim}) - ->i_(a("keepdim"), keepdim); - return r; + return create(t("sum"), {*this, insertConstant(at::IntList{dim}), insertConstant(keepdim)})[0]; } SymbolicVariable squeeze(Value* dim) const { - Node * n; - auto r = create(t("squeeze"), {*this, dim}, 1, &n)[0]; - return r; + return create(t("squeeze"), {*this, dim})[0]; } SymbolicVariable squeeze(int dim) const { - Node * n; - auto r = create(t("squeeze"), {*this}, 1, &n)[0]; - n->i_(a("dim"), dim); - return r; + return squeeze(insertConstant(dim)); + } + SymbolicVariable unsqueeze(Value* dim) const { + return create(t("unsqueeze"), {*this, dim})[0]; } SymbolicVariable unsqueeze(int dim) const { - Node * n; - auto r = create(t("unsqueeze"), {*this}, 1, &n)[0]; - n->i_(a("dim"), dim); - return r; + return unsqueeze(insertConstant(dim)); + } + SymbolicVariable view(Value* sizes) const { + return create(aten::view, {*this, sizes})[0]; } SymbolicVariable view(std::vector sizes) const { - Node *n; - auto r = create(aten::view, {*this}, 1, &n)[0]; - n->is_(a("size"), std::move(sizes)); - return r; + return view(insertConstant(sizes)); + } + SymbolicVariable reshape(Value* sizes) const { + return create(aten::reshape, {*this, sizes})[0]; + } + SymbolicVariable reshape(std::vector sizes) const { + return reshape(insertConstant(sizes)); } SymbolicVariable addmm(SymbolicVariable mat1, SymbolicVariable mat2) const { - Node *n; - auto r = create(aten::addmm, {*this, mat1, mat2}, 1, &n)[0]; - n->t_(a("alpha"), at::CPU(at::kFloat).scalarTensor(1.0)); - n->t_(a("beta"), at::CPU(at::kFloat).scalarTensor(1.0)); - return r; + return create(aten::addmm, {*this, mat1, mat2, insertConstant(1.0), insertConstant(1.0)})[0]; } Value * value() const { return v; } private: + Value * insertConstant(IValue value) const { + return jit::insertConstant(*v->owningGraph(), value); + } SymbolicVariable typeLike(SymbolicVariable other) { if (auto 
other_type = other.v->type()->cast()) v->setType(other_type->contiguous()); diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp index ecb8c9b3779816..8c9763f88353e5 100644 --- a/torch/csrc/jit/test_jit.cpp +++ b/torch/csrc/jit/test_jit.cpp @@ -641,7 +641,7 @@ std::string toString(std::shared_ptr& graph) { void testDifferentiate(std::ostream & out) { auto graph = std::make_shared(); at::ScalarType s = at::ScalarType::Float; - auto type = std::shared_ptr(new TensorType(s, -1, {2, 3, 4}, {12, 4, 1})); + auto type = TensorType::create(s, -1, {2, 3, 4}, {12, 4, 1}); // Build up a fake graph auto a = SymbolicVariable::asNewInput(*graph, type); @@ -668,7 +668,7 @@ void testDifferentiate(std::ostream & out) { void testDifferentiateWithRequiresGrad(std::ostream & out) { auto graph = std::make_shared(); at::ScalarType s = at::ScalarType::Float; - auto type = std::shared_ptr(new TensorType(s, -1, {2, 3, 4}, {12, 4, 1})); + auto type = TensorType::create(s, -1, {2, 3, 4}, {12, 4, 1}); // Build up a fake graph auto a = SymbolicVariable::asNewInput(*graph, type); diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index 5c998e3fc690bf..aec6eb4ddc9447 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -19,6 +19,51 @@ namespace torch { namespace jit { namespace tracer { //////////////////////////////////////////////////////////////////////////////// namespace detail { +template +void genericAddInput(Node *n, T value) { + n->addInput(insertConstant(*n->owningGraph(), value)); +} + +void badArgType() { + throw std::runtime_error("Found an unsupported argument type in the JIT tracer. File a bug report."); +} + + +void addInputs(Node *n, const char * name, int64_t value) { genericAddInput(n, value); } +void addInputs(Node *n, const char * name, bool value) { genericAddInput(n, value); } +void addInputs(Node *n, const char * name, double value) { genericAddInput(n, value); } +void addInputs(Node *n, const char * name, const at::Scalar& value) { genericAddInput(n, value); } +void addInputs(Node *n, const char * name, const at::Tensor& value) { n->addInput(getValueTrace(value)); } +void addInputs(Node *n, const char * name, const std::string& value) { badArgType(); } +void addInputs(Node *n, const char * name, const at::SparseTensorRef& value) { badArgType(); } + +void addInputs(Node *n, const char * name, at::TensorList value) { + for (auto & t : value) { + n->addInput(getValueTrace(t)); + } +} + +void addInputs(Node *n, const char * name, at::IntList value) { + using ArgumentStash = jit::tracer::ArgumentStash; + std::vector info = ArgumentStash::hasIntList(name) ? + ArgumentStash::popIntList(name) : + ArgumentStash::IntListTrace(value.size()); + + auto& g = getTracingState()->graph; + for (size_t i = 0; i < info.size(); ++i) { + if (info[i] != nullptr) continue; + info[i] = insertConstant(*g, value[i]); + } + for (jit::Value* v : info) { + if (*v->type() != *jit::IntType::get()) { + throw std::runtime_error( + "Type mismatch in setposattr for IntList. 
Check that your program " + "is valid without tracing, and please file a bug report if it is."); + } + } + n->addInput(g->insertNode(g->createList(jit::IntType::get(), info))->output()); +} + thread_local std::shared_ptr tracing_state; } // namespace detail @@ -36,13 +81,6 @@ TracingState::TracingState() TracingState::~TracingState() = default; -PreTraceInfo preRecordTrace(Symbol op, - at::ArrayRef inputs) { - return makePreTraceInfo(inputs, [&op](const std::shared_ptr& state, Graph& graph) { - return graph.create(op, 0 /* initial outputs */); - }); -} - void postRecordTrace(const PreTraceInfo& info, at::ArrayRef outputs) { for (size_t i = 0; i < outputs.size(); i++) { diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h index bde3edf52221e1..c9780119a385a0 100644 --- a/torch/csrc/jit/tracer.h +++ b/torch/csrc/jit/tracer.h @@ -193,36 +193,58 @@ struct PreTraceInfo { Node *n; }; -TORCH_API PreTraceInfo preRecordTrace(Symbol op, at::ArrayRef inputs); -TORCH_API void postRecordTrace(const PreTraceInfo& info, at::ArrayRef outputs); TORCH_API void recordSourceLocation(Node* n); TORCH_API void setRecordSourceLocation(void (*v)(Node*)); -// We must record the nodes of inputs before we actually carry out -// the operation, because an inplace operation may destroy the information -// we're interested in. See #4480. -template -PreTraceInfo makePreTraceInfo(at::ArrayRef inputs, F ctor) { +namespace detail { + +// NB: those serve both as an intermediate steps in addInputs below, +// as well as the overloads that terminate template recursion +void addInputs(Node *n, const char * name, int64_t value); +void addInputs(Node *n, const char * name, bool value); +void addInputs(Node *n, const char * name, double value); +void addInputs(Node *n, const char * name, const at::Scalar& value); +void addInputs(Node *n, const char * name, const at::Tensor& value); +void addInputs(Node *n, const char * name, at::IntList value); +void addInputs(Node *n, const char * name, at::TensorList value); +void addInputs(Node *n, const char * name, const std::string& value); +void addInputs(Node *n, const char * name, const at::SparseTensorRef& value); + +template +void addInputs(Node *n, const char * name, std::array value) { + throw std::runtime_error("Found an unsupported argument type in the JIT tracer. File a bug report."); +} + +template +void addInputs(Node *n, const char * arg_name, T arg, const char * next_arg_name, Args... args) { + addInputs(n, arg_name, arg); + addInputs(n, next_arg_name, args...); +} + +} // namespace detail + +// NB: if you change this function, you might want to take a look at +// preRecordPythonTrace from python_tracer.cpp +template +PreTraceInfo preRecordTrace(Symbol op, Args... inputs) { PreTraceInfo info; auto & state = getTracingState(); auto & graph = state->graph; - Node *n = ctor(state, *graph); + Node * n = info.n = graph->create(op, /*outputs=*/0); recordSourceLocation(n); - for (const Variable & input : inputs) { - n->addInput(getValueTrace(input)); - } + detail::addInputs(n, inputs...); // NB: Order matters. This must append after inputs but before outputs. 
graph->appendNode(n); - info.n = n; - return info; } +TORCH_API void postRecordTrace(const PreTraceInfo& info, at::ArrayRef outputs); + TORCH_API autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim); }}} // namespace torch::jit::tracer diff --git a/torch/csrc/jit/type.cpp b/torch/csrc/jit/type.cpp index b657d4e935f17a..bf28588ad7eca6 100644 --- a/torch/csrc/jit/type.cpp +++ b/torch/csrc/jit/type.cpp @@ -45,29 +45,29 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { } TypePtr DynamicType::get() { - static auto value = std::make_shared(); + static auto value = DynamicType::create(); return value; } TypePtr NumberType::get() { - static auto value = std::make_shared(); + static auto value = NumberType::create(); return value; } TypePtr IntType::get() { - static auto value = std::make_shared(); + static auto value = IntType::create(); return value; } TypePtr FloatType::get() { - static auto value = std::make_shared(); + static auto value = FloatType::create(); return value; } TypePtr ListType::ofTensors() { - static auto value = std::make_shared(DynamicType::get()); + static auto value = ListType::create(DynamicType::get()); return value; } TypePtr ListType::ofInts() { - static auto value = std::make_shared(IntType::get()); + static auto value = ListType::create(IntType::get()); return value; } diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h index 177833f23d938e..dc2cea1fa50b94 100644 --- a/torch/csrc/jit/type.h +++ b/torch/csrc/jit/type.h @@ -31,7 +31,7 @@ struct Type; using TypePtr = std::shared_ptr; -struct TORCH_API Type { +struct TORCH_API Type : std::enable_shared_from_this { private: TypeKind kind_; @@ -44,8 +44,8 @@ struct TORCH_API Type { // subtyping relation. By default, we return true for the case // when the type is exactly equal - virtual bool isSubtypeOf(const Type& rhs) const { - return *this == rhs; + virtual bool isSubtypeOf(const TypePtr rhs) const { + return *this == *rhs; } // user-friendly form of the type, separate from // operator<< which is verbose and unambiguous @@ -58,26 +58,26 @@ struct TORCH_API Type { // Dynamically cast this object to the subclass indicated by the // template variable, returning nullptr if the cast is invalid.. template - T* cast() { + std::shared_ptr cast() { if (T::Kind == kind()) - return static_cast(this); + return std::static_pointer_cast(shared_from_this()); return nullptr; } template - const T* cast() const { + std::shared_ptr cast() const { if (T::Kind == kind()) - return static_cast(this); + return std::static_pointer_cast(shared_from_this()); return nullptr; } template - T* expect() { + std::shared_ptr expect() { JIT_ASSERT(T::Kind == kind()); - return static_cast(this); + return std::static_pointer_cast(shared_from_this()); } template - const T* expect() const { + std::shared_ptr expect() const { JIT_ASSERT(T::Kind == kind()); - return static_cast(this); + return std::static_pointer_cast(shared_from_this()); } virtual ~Type() {} }; @@ -86,10 +86,15 @@ inline bool operator!=(const Type & lhs, const Type & rhs) { return !(lhs == rhs); } +struct DynamicType; +using DynamicTypePtr = std::shared_ptr; // This node represents a single Tensor value, with an unknown shape. struct TORCH_API DynamicType : public Type { - DynamicType() - : Type(TypeKind::DynamicType) {} + template + static DynamicTypePtr create( T&& ... all ) { + return DynamicTypePtr(new DynamicType( std::forward(all)... 
)); + } + bool operator==(const Type& rhs) const override { return rhs.kind() == kind(); } @@ -99,6 +104,9 @@ struct TORCH_API DynamicType : public Type { static const TypeKind Kind = TypeKind::DynamicType; // global singleton static TypePtr get(); +private: + DynamicType() + : Type(TypeKind::DynamicType) {} }; struct TensorType; @@ -106,21 +114,18 @@ using TensorTypePtr = std::shared_ptr; // This node represents a single Tensor value with a specific size struct TORCH_API TensorType : public Type { friend struct Type; - TensorType(const at::Tensor& tensor) - : Type(TypeKind::TensorType) - , scalar_type_(tensor.type().scalarType()) - , device_(tensor.type().is_cuda() ? tensor.get_device() : -1) - , sizes_(tensor.sizes()) - , strides_(tensor.strides()) {} - TensorType(at::ScalarType scalar_type, int device, at::IntList sizes) - : TensorType(scalar_type, device, sizes, TensorType::contiguousStridesOf(sizes)) {} - TensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides) - : Type(TypeKind::TensorType) - , scalar_type_(scalar_type) - , device_(device) - , sizes_(sizes) - , strides_(strides) - {} + template + static TensorTypePtr create( T&& ... all ) { + return TensorTypePtr(new TensorType( std::forward(all)... )); + } + + // overloaded create variadic template argument as it could not distinguish initializer list + static TensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes) { + return TensorTypePtr(new TensorType(scalar_type, device, sizes)); + } + static TensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides) { + return TensorTypePtr(new TensorType(scalar_type, device, sizes, strides)); + } static const TypeKind Kind = TypeKind::TensorType; @@ -130,7 +135,7 @@ struct TORCH_API TensorType : public Type { const std::vector& strides() const { return strides_; } TypePtr withSizesStrides(at::IntList sizes, at::IntList strides) const { - return std::make_shared(scalar_type_, device_, sizes, strides); + return TensorType::create(scalar_type_, device_, sizes, strides); } TypePtr withSizes(at::IntList sizes) const { @@ -138,13 +143,13 @@ struct TORCH_API TensorType : public Type { } TensorTypePtr contiguous() const { - auto t = std::make_shared(*this); + auto t = TensorType::create(*this); t->strides_ = TensorType::contiguousStridesOf(sizes_); return t; } TensorTypePtr toScalarType(at::ScalarType type){ - auto t = std::make_shared(*this); + auto t = TensorType::create(*this); t->scalar_type_ = type; return t; } @@ -158,8 +163,8 @@ struct TORCH_API TensorType : public Type { strides() == rt->strides() && device() == rt->device(); } - bool isSubtypeOf(const Type& rhs) const override { - return *this == rhs || rhs.kind() == TypeKind::DynamicType; + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs || rhs->kind() == TypeKind::DynamicType; } std::string str() const override { // str is used for user-facing error messages, where we @@ -176,6 +181,21 @@ struct TORCH_API TensorType : public Type { static TypePtr fromNumberType(TypePtr typ); private: + TensorType(const at::Tensor& tensor) + : Type(TypeKind::TensorType) + , scalar_type_(tensor.type().scalarType()) + , device_(tensor.type().is_cuda() ? 
tensor.get_device() : -1) + , sizes_(tensor.sizes()) + , strides_(tensor.strides()) {} + TensorType(at::ScalarType scalar_type, int device, at::IntList sizes) + : TensorType(scalar_type, device, sizes, TensorType::contiguousStridesOf(sizes)) {} + TensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides) + : Type(TypeKind::TensorType) + , scalar_type_(scalar_type) + , device_(device) + , sizes_(sizes) + , strides_(strides) + {} static std::vector contiguousStridesOf(at::IntList sizes) { std::vector strides(sizes.size()); if(sizes.size() == 0) // zero-dim case @@ -192,11 +212,15 @@ struct TORCH_API TensorType : public Type { std::vector strides_; }; +struct ListType; +using ListTypePtr = std::shared_ptr; + struct TORCH_API ListType : public Type { friend struct Type; - static const TypeKind Kind = TypeKind::ListType; - ListType(TypePtr elem) - : Type(TypeKind::ListType), elem(elem) {} + template + static ListTypePtr create( T&& ... all ) { + return ListTypePtr(new ListType( std::forward(all)... )); + } bool operator==(const Type& rhs) const override { if(auto rhs_ = rhs.cast()) { return *getElementType() == *rhs_->getElementType(); @@ -215,35 +239,41 @@ struct TORCH_API ListType : public Type { static TypePtr ofTensors(); static TypePtr ofInts(); private: + ListType(TypePtr elem) + : Type(TypeKind::ListType), elem(elem) {} + static const TypeKind Kind = TypeKind::ListType; TypePtr elem; }; +struct TupleType; +using TupleTypePtr = std::shared_ptr; + struct TORCH_API TupleType : public Type { friend struct Type; - TupleType(std::vector elements_) - : Type(TypeKind::TupleType) - , elements_(std::move(elements_)) {} - static const TypeKind Kind = TypeKind::TupleType; + template + static TupleTypePtr create( T&& ... all ) { + return TupleTypePtr(new TupleType( std::forward(all)... )); + } at::ArrayRef elements() const { return elements_; } bool operator==(const Type& rhs) const override { - return compare(rhs, [](const Type& a, const Type& b) { - return a == b; + return compare(rhs, [](const TypePtr a, const TypePtr b) { + return *a == *b; }); } - bool isSubtypeOf(const Type& rhs) const override { + bool isSubtypeOf(const TypePtr rhs) const override { // e.g. 
(Tensor, Tensor, Tensor) <: List[Tensor] - if(auto lt = rhs.cast()) { + if(auto lt = rhs->cast()) { for(auto e : elements()) { - if(!e->isSubtypeOf(*lt->getElementType())) + if(!e->isSubtypeOf(lt->getElementType())) return false; } return true; } // co-variant rules for tuples - return compare(rhs, [](const Type& a, const Type&b) { - return a.isSubtypeOf(b); + return compare(*rhs, [](const TypePtr a, const TypePtr b) { + return a->isSubtypeOf(b); }); } std::string str() const override { @@ -258,7 +288,12 @@ struct TORCH_API TupleType : public Type { return ss.str(); } private: - bool compare(const Type& rhs, std::function fn) const { + TupleType(std::vector elements_) + : Type(TypeKind::TupleType) + , elements_(std::move(elements_)) {} + static const TypeKind Kind = TypeKind::TupleType; + + bool compare(const Type& rhs, std::function fn) const { if(rhs.kind() != kind()) return false; const auto & l_elements = elements(); @@ -266,7 +301,7 @@ struct TORCH_API TupleType : public Type { if(l_elements.size() != r_elements.size()) return false; for(size_t i = 0; i < l_elements.size(); ++i) { - if(!fn(*l_elements[i], *r_elements[i])) + if(!fn(l_elements[i], r_elements[i])) return false; } return true; @@ -274,10 +309,14 @@ struct TORCH_API TupleType : public Type { std::vector elements_; }; +struct NumberType; +using NumberTypePtr = std::shared_ptr; // This node represents a Python number value struct TORCH_API NumberType : public Type { - NumberType() - : Type(TypeKind::NumberType) {} + template + static NumberTypePtr create( T&& ... all ) { + return NumberTypePtr(new NumberType( std::forward(all)... )); + } bool operator==(const Type& rhs) const override { return rhs.kind() == kind(); } @@ -287,42 +326,59 @@ struct TORCH_API NumberType : public Type { static const TypeKind Kind = TypeKind::NumberType; // global singleton static TypePtr get(); +private: + NumberType() + : Type(TypeKind::NumberType) {} }; +struct FloatType; +using FloatTypePtr = std::shared_ptr; // This node represents a Python float number value struct TORCH_API FloatType : public Type { - FloatType() - : Type(TypeKind::FloatType) {} + template + static FloatTypePtr create( T&& ... all ) { + return FloatTypePtr(new FloatType( std::forward(all)... )); + } bool operator==(const Type& rhs) const override { return rhs.kind() == kind(); } std::string str() const override { return "float"; } - bool isSubtypeOf(const Type& rhs) const override { - return *this == rhs || rhs.kind() == TypeKind::NumberType; + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs || rhs->kind() == TypeKind::NumberType; } static const TypeKind Kind = TypeKind::FloatType; // global singleton static TypePtr get(); +private: + FloatType() + : Type(TypeKind::FloatType) {} }; +struct IntType; +using IntTypePtr = std::shared_ptr; // This node represents a Python int number value struct TORCH_API IntType : public Type { - IntType() - : Type(TypeKind::IntType) {} + template + static IntTypePtr create( T&& ... all ) { + return IntTypePtr(new IntType( std::forward(all)... 
)); + } bool operator==(const Type& rhs) const override { return rhs.kind() == kind(); } std::string str() const override { return "int"; } - bool isSubtypeOf(const Type& rhs) const override { - return *this == rhs || rhs.kind() == TypeKind::NumberType; + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs || rhs->kind() == TypeKind::NumberType; } static const TypeKind Kind = TypeKind::IntType; // global singleton static TypePtr get(); +private: + IntType() + : Type(TypeKind::IntType) {} }; @@ -331,10 +387,10 @@ TORCH_API std::ostream& operator<<(std::ostream & out, const Type & t); // e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) inline TypePtr unshapedType(const TypePtr& type) { - if(TupleType* t = type->cast()) { - return std::make_shared(fmap(t->elements(), unshapedType)); - } else if(ListType* t = type->cast()) { - return std::make_shared(unshapedType(t->getElementType())); + if(TupleTypePtr t = type->cast()) { + return TupleType::create(fmap(t->elements(), unshapedType)); + } else if(ListTypePtr t = type->cast()) { + return ListType::create(unshapedType(t->getElementType())); } else if(type->kind() == TypeKind::TensorType) { return DynamicType::get(); } else { @@ -343,13 +399,11 @@ inline TypePtr unshapedType(const TypePtr& type) { } inline TypePtr TensorType::fromNumberType(TypePtr typ) { - JIT_ASSERT(typ->isSubtypeOf(*NumberType::get())); - if(typ->isSubtypeOf(*IntType::get())) { - TensorType tt(at::kLong, -1, {}); - return std::make_shared(std::move(tt)); - } else if(typ->isSubtypeOf(*FloatType::get())) { - TensorType tt(at::kFloat, -1, {}); - return std::make_shared(std::move(tt)); + JIT_ASSERT(typ->isSubtypeOf(NumberType::get())); + if(typ->isSubtypeOf(IntType::get())) { + return TensorType::create(at::kLong, -1, {}); + } else if(typ->isSubtypeOf(FloatType::get())) { + return TensorType::create(at::kFloat, -1, {}); } AT_ERROR("unknown number type", typ->str()); } diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index d5f08aa10273fb..454151afed8201 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -411,7 +411,7 @@ class ParameterDict(Module): class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() - self.choices = nn.ParameterDict({ + self.params = nn.ParameterDict({ 'left': nn.Parameter(torch.randn(5, 10)), 'right': nn.Parameter(torch.randn(5, 10)) }) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index a00ff3dd9c268c..61c93a7e810f98 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -642,6 +642,10 @@ def _load_from_state_dict(self, state_dict, prefix, metadata, strict, missing_ke if key in state_dict: input_param = state_dict[key] + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + if input_param.shape != param.shape: # local shape should match the one in checkpoint error_msgs.append('size mismatch for {}: copying a param of {} from checkpoint, ' diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 4ac6ca7f887195..99fd4dd1b25a7b 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -11,7 +11,7 @@ import torch.onnx.utils from collections import Iterable -from functools import partial +from functools import partial, wraps import itertools # EDITING THIS FILE? READ THIS FIRST! 
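
Illustration (not part of the patch): the backward-compatibility branch added to _load_from_state_dict above maps a 1-element 1-dim tensor saved by a 0.3.x checkpoint back onto a 0-dim parameter expected by 0.4+ modules. A minimal standalone sketch of that shape fix-up; the concrete values here are made up:

import torch

param = torch.tensor(0.5)           # 0-dim parameter in a 0.4+ module
input_param = torch.tensor([0.5])   # 1-dim value from a 0.3.x checkpoint

# same condition as the new branch in _load_from_state_dict
if len(param.shape) == 0 and len(input_param.shape) == 1:
    input_param = input_param[0]    # indexing yields a 0-dim tensor

assert input_param.shape == param.shape  # shapes now match the local parameter
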
@@ -32,6 +32,59 @@ # --------------------------------------------------------------------- +def _parse_arg(value, desc): + if desc == 'v' or not _is_value(value): + return value + + if value.node().kind() != 'onnx::Constant': + raise RuntimeError("ONNX symbolic expected a constant value in the trace") + tval = value.node()['value'] + if desc == 'i': + return int(tval) + elif desc == 'f': + return float(tval) + elif desc == 't': + return tval + elif desc == 'is': + return [int(v) for v in tval] + else: + raise RuntimeError("Casting constants to `{}` is not implemented".format(desc)) + + +def _maybe_get_const(value, desc): + if _is_value(value) and value.node().kind() == 'onnx::Constant': + return _parse_arg(value, desc) + return value + + +def _maybe_get_scalar(value): + value_t = _maybe_get_const(value, 't') + if isinstance(value_t, torch.Tensor) and value_t.shape == (): + return value_t + return value + + +def _get_const(value, desc, arg_name): + if _is_value(value) and value.node().kind() != 'onnx::Constant': + raise RuntimeError("ONNX symbolic expected a constant value of the {} argument".format(arg_name)) + return _parse_arg(value, desc) + + +def parse_args(*arg_descriptors): + def decorator(fn): + def wrapper(g, *args): + assert len(arg_descriptors) == len(args) + args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)] + return fn(g, *args) + # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround + try: + wrapper = wraps(fn)(wrapper) + except Exception: + pass + return wrapper + return decorator + + def _scalar(x): """Convert a scalar tensor into a Python value.""" assert x.numel() == 1 @@ -137,27 +190,33 @@ def unused(g): return g.op("prim::Undefined") +@parse_args('v', 'v', 't') def add(g, self, other, alpha): if _scalar(alpha) != 1: return _unimplemented("add", "alpha != 1") # See Note [Pointwise by scalar] + other = _maybe_get_scalar(other) return g.op("Add", self, _if_scalar_type_as(other, self), **_broadcast_if_scalar(other)) +@parse_args('v', 'v', 't') def sub(g, self, other, alpha): if _scalar(alpha) != 1: return _unimplemented("sub", "alpha != 1") # See Note [Pointwise by scalar] + other = _maybe_get_scalar(other) return g.op("Sub", self, _if_scalar_type_as(other, self), **_broadcast_if_scalar(other)) def mul(g, self, other): # See Note [Pointwise by scalar] + other = _maybe_get_scalar(other) return g.op("Mul", self, _if_scalar_type_as(other, self), **_broadcast_if_scalar(other)) def div(g, self, other): # See Note [Pointwise by scalar] + other = _maybe_get_scalar(other) return g.op("Div", self, _if_scalar_type_as(other, self), **_broadcast_if_scalar(other)) @@ -166,9 +225,9 @@ def reciprocal(g, self): # This syntax is Python 2 portable -def cat(g, *tensors, **kwargs): - dim = kwargs.pop("dim") - assert not kwargs +def cat(g, *args): + dim = _get_const(args[-1], 'i', 'dim') + tensors = args[:-1] return g.op("Concat", *tensors, axis_i=dim) @@ -188,6 +247,7 @@ def matmul(g, self, other): return g.op("MatMul", self, other) +@parse_args('v', 'v', 'v', 't', 't') def addmm(g, self, mat1, mat2, beta, alpha): return g.op("Gemm", mat1, mat2, self, beta_f=_scalar(beta), alpha_f=_scalar(alpha)) @@ -211,12 +271,13 @@ def sigmoid(g, self): def _reduce_op_symbolic(onnx_op_name): def symbolic(g, self, dim=None, keepdim=None): params = {} - if dim is not None: - if isinstance(dim, numbers.Number): - dim = [dim] - params['axes_i'] = dim - params['keepdims_i'] = int(bool(keepdim)) - return g.op(onnx_op_name, self, 
**params) + if dim is None: + # all-reduce path + return g.op(onnx_op_name, self, keepdims_i=0) + else: + # dim-reduce path + dim, keepdim = _get_const(dim, 'i', 'dim'), _get_const(keepdim, 'i', 'keepdim') + return g.op(onnx_op_name, self, axes_i=[dim], keepdims_i=keepdim) return symbolic mean = _reduce_op_symbolic('ReduceMean') @@ -224,6 +285,7 @@ def symbolic(g, self, dim=None, keepdim=None): prod = _reduce_op_symbolic('ReduceProd') +@parse_args('v', 'i') def cumsum(g, input, dim): return g.op("ATen", input, operator_s="cumsum", dim_i=dim) @@ -241,6 +303,7 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): return g.op("Gather", weight, indices) +@parse_args('v', 'v', 'v', 'i', 'i', 'i') def embedding_bag(g, embedding_matrix, indices, @@ -260,14 +323,11 @@ def embedding_bag(g, def size(g, self, dim): - if _is_value(dim): - if dim.node().kind() != 'onnx::Constant': - raise RuntimeError("ONNX export only supports constant dim values in .size()") - dim = int(dim.node().t('value')) full_shape = g.op("Shape", self) - return select(g, full_shape, dim=0, index=dim) + return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim) +@parse_args('v', 'i', 'i') def transpose(g, self, dim0, dim1): if dim0 == dim1: # micro-optimization return self @@ -278,6 +338,7 @@ def transpose(g, self, dim0, dim1): return g.op("Transpose", self, perm_i=axes) +@parse_args('v', 'is') def permute(g, self, dims): if dims == list(range(0, len(dims))): return self @@ -285,6 +346,7 @@ def permute(g, self, dims): def view(g, self, size): + size = _maybe_get_const(size, 'is') if _is_value(size): shape = size else: @@ -296,16 +358,12 @@ def view(g, self, size): return g.op("Reshape", self, shape) -def stack(g, *tensors, **kwargs): - dim = kwargs.pop('dim') - if kwargs: - raise RuntimeError("Unexpected kwargs: " + ','.join(kwargs.keys())) - if len(tensors) < 1: - raise RuntimeError("Expected at least one argument to stack node") - unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in tensors] - return g.op("Concat", *unsqueezed, axis_i=dim) +def stack(g, *args): + unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in args[:-1]] + [args[-1]] + return concat(g, *unsqueezed) +@parse_args('v', 'i', 'i') def split(g, self, split_size, dim): size = self.type().sizes()[dim] splits = [split_size] * (size // split_size) @@ -319,11 +377,13 @@ def split(g, self, split_size, dim): # less sensitive to changes in input size. 
# TODO: Once we have proper scoping, stop reimplementing chunk, delete this # method, and use the desugared version +@parse_args('v', 'i', 'i') def chunk(g, self, chunks, dim): split_size = (self.type().sizes()[dim] + chunks - 1) // chunks return split(g, self, split_size, dim) +@parse_args('v', 'i', 'i') def select(g, self, dim, index): slice_node = g.op("Slice", self, axes_i=[dim], starts_i=[index], ends_i=[index + 1]) return g.op("Squeeze", slice_node, axes_i=[dim]) @@ -336,7 +396,7 @@ def squeeze(g, self, dim=None): if size == 1: dims.append(i) else: - dims = [dim] + dims = [_get_const(dim, 'i', 'dim')] return g.op("Squeeze", self, axes_i=dims) @@ -348,6 +408,7 @@ def relu(g, input): return g.op("Relu", input) +@parse_args('v', 't', 't') def threshold(g, self, threshold, value): # See Note [Export inplace] if _scalar(threshold) != 0: @@ -358,11 +419,13 @@ def threshold(g, self, threshold, value): def leaky_relu(g, input, negative_slope, inplace=False): + negative_slope = _get_const(negative_slope, 't', 'negative_slope') # See Note [Export inplace] # TODO: Talk to ONNX about unconditional cast of scalar to float return g.op("LeakyRelu", input, alpha_f=_scalar(negative_slope)) +@parse_args('v', 'i') def glu(g, input, dim): assert input.type().sizes()[dim] % 2 == 0 @@ -370,7 +433,8 @@ def glu(g, input, dim): return g.op('Mul', first, g.op('Sigmoid', second)) -def softmax(g, input, dim=None): +@parse_args('v', 'i') +def softmax(g, input, dim): # Softmax does normalization at vector level. # PyTorch and ONNX use different strategies to split the input tensor into vectors. # Thus dim and axis have different meanings. @@ -394,12 +458,14 @@ def softmax(g, input, dim=None): return g.op('Softmax', input, axis_i=dim) +@parse_args('v', 't', 'v') def softplus(g, self, beta, threshold): if beta != 1: return _unimplemented("beta", "has to be 1") return g.op('Softplus', self) +@parse_args('v', 'is', 'is', 'is', 'is', 'i') def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode): if ceil_mode: return _unimplemented("max_pool1d_with_indices", "ceil_mode") @@ -414,6 +480,7 @@ def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ce return r, None +@parse_args('v', 'is', 'is', 'is', 'is', 'i') def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode): if ceil_mode: return _unimplemented("max_pool2d_with_indices", "ceil_mode") @@ -428,6 +495,7 @@ def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ce return r, None +@parse_args('v', 'is', 'is', 'is', 'is', 'i') def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode): if ceil_mode: return _unimplemented("max_pool3d_with_indices", "ceil_mode") @@ -443,6 +511,7 @@ def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ce def _avg_pool(name, tuple_fn): + @parse_args('v', 'is', 'is', 'is', 'i', 'i') def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad): if ceil_mode: return _unimplemented("avg_pool2d", "ceil_mode") @@ -469,6 +538,7 @@ def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include avg_pool3d = _avg_pool('avg_pool3d', _triple) +@parse_args('v', 'is') def reflection_pad(g, input, padding): from torch.autograd._functions.utils import prepare_onnx_paddings mode = "reflect" @@ -476,6 +546,7 @@ def reflection_pad(g, input, padding): return g.op("Pad", input, pads_i=paddings, mode_s=mode) +@parse_args('v', 'is') def replication_pad(g, 
input, padding): from torch.autograd._functions.utils import prepare_onnx_paddings mode = "edge" @@ -491,6 +562,7 @@ def replication_pad(g, input, padding): replication_pad3d = replication_pad +@parse_args('v', 'is') def upsample_nearest2d(g, input, output_size): return g.op("Upsample", input, height_scale_f=float(output_size[-2]) / input.type().sizes()[-2], @@ -498,6 +570,7 @@ def upsample_nearest2d(g, input, output_size): mode_s="nearest") +@parse_args('v', 'is', 'i') def upsample_bilinear2d(g, input, output_size, align_corners): if align_corners: return _unimplemented("upsample_bilinear2d", "align_corners == True") @@ -508,10 +581,12 @@ def upsample_bilinear2d(g, input, output_size, align_corners): def gt(g, input, other): + other = _maybe_get_scalar(other) return g.op("Greater", input, _if_scalar_type_as(other, input), **_broadcast_if_scalar(other)) def lt(g, input, other): + other = _maybe_get_scalar(other) return g.op("Less", input, _if_scalar_type_as(other, input), **_broadcast_if_scalar(other)) @@ -523,10 +598,12 @@ def le(g, input, other): return g.op("Not", gt(g, other, input)) +@parse_args('v', 'i') def log_softmax(g, input, dim=None): return g.op("LogSoftmax", input, axis_i=dim) +@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i') def _convolution(g, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled): weight_size = weight.type().sizes() @@ -560,6 +637,7 @@ def _convolution(g, input, weight, bias, stride, padding, dilation, return n +@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i') def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled): input_sizes = input.type().sizes() if len(input_sizes) == 2: @@ -586,10 +664,12 @@ def batch_norm(g, input, weight, bias, running_mean, running_var, training, mome return res +@parse_args('v', 'i', 'i', 'i') def unfold(g, input, dimension, size, step): return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step) +@parse_args('v', 't', 't') def elu(g, input, alpha, scale): if scale and scale != 1.: return _unimplemented("scale", "does not support scale in Elu") @@ -601,12 +681,13 @@ def selu(g, input): return g.op("Selu", input) -def index_select(g, self, index, dim): +@parse_args('v', 'i', 'v') +def index_select(g, self, dim, index): return g.op("Gather", self, index, axis_i=dim) -def index_put(g, *inputs, **kwargs): - return g.op("ATen", *inputs, operator_s='index_put', **kwargs) +def index_put(g, *inputs): + return g.op("ATen", *inputs, operator_s='index_put') def type_as(g, self, other): @@ -631,29 +712,33 @@ def abs(g, self): def pow(g, self, exponent): + exponent = _maybe_get_scalar(exponent) return g.op("Pow", self, _if_scalar_type_as(exponent, self), **_broadcast_if_scalar(exponent)) +@parse_args('v', 'f', 'f') def clamp(g, self, min, max): return g.op("Clip", self, min_f=min, max_f=max) +@parse_args('v', 'f') def clamp_min(g, self, min): return g.op("Clip", self, min_f=min) +@parse_args('v', 'f') def clamp_max(g, self, max): return g.op("Clip", self, max_f=max) # torch.max (same for torch.min) actually has two interfaces smashed together: # torch.max(x, dim, keepdim) and torch.max(x, y) -def max(g, self, *args, **kwargs): - dim = kwargs.get("dim", None) - if dim is None and isinstance(args[0], numbers.Number): - dim = args[0] - if dim is not None: - keepdim = kwargs.get("keepdim", False) +def max(g, self, dim_or_y, keepdim=None): + if keepdim is 
None: + return g.op("Max", self, dim_or_y) + else: + dim = _get_const(dim_or_y, 'i', 'dim') + keepdim = _get_const(keepdim, 'i', 'keepdim') # TODO: export it as ReduceMax return g.op("ATen", self, @@ -661,27 +746,21 @@ def max(g, self, *args, **kwargs): dim_i=dim, keepdim_i=keepdim, outputs=2) - else: - (other,) = args - return g.op("Max", self, other) -def min(g, self, *args, **kwargs): - dim = kwargs.get("dim", None) - if dim is None and isinstance(args[0], numbers.Number): - dim = args[0] - if dim is not None: - keepdim = kwargs.get("keepdim", False) - # TODO: export it as ReduceMin +def min(g, self, dim_or_y, keepdim=None): + if keepdim is None: + return g.op("Min", self, dim_or_y) + else: + dim = _get_const(dim_or_y, 'i', 'dim') + keepdim = _get_const(keepdim, 'i', 'keepdim') + # TODO: export it as ReduceMax return g.op("ATen", self, operator_s="min", dim_i=dim, keepdim_i=keepdim, outputs=2) - else: - (other,) = args - return g.op("Min", self, other) def eq(g, self, other): @@ -692,6 +771,7 @@ def exp(g, self): return g.op("Exp", self) +@parse_args('v', 't', 'i', 'i') def norm(g, self, p, dim, keepdim): if p == 1: f = _reduce_op_symbolic("ReduceL1") @@ -702,10 +782,12 @@ def norm(g, self, p, dim, keepdim): return f(g, self, dim=dim, keepdim=keepdim) +@parse_args('v', 'v', 'v', 'i') def conv_tbc(g, input, weight, bias, pad): return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad) +@parse_args('v', 'i', 'i') def _unique(g, input, sorted, return_inverse): return g.op("ATen", input, operator_s="_unique", sorted_i=sorted, return_inverse_i=return_inverse, outputs=2) @@ -746,7 +828,7 @@ def _cast_func_template(to_i, g, input, non_blocking): for k, v in cast_pytorch_to_onnx.items(): name = '_cast_{}'.format(k) - globals()[name] = partial(_cast_func_template, v) + globals()[name] = parse_args('v', 'i')(partial(_cast_func_template, v)) def zeros_like(g, input): @@ -755,15 +837,17 @@ def zeros_like(g, input): def full_like(g, input, fill_value): # TODO: a more efficient implementation (ConstantFill?) 
- return add(g, zeros_like(g, input), fill_value, alpha=torch.tensor(1)) + return add(g, zeros_like(g, input), fill_value, g.op("Constant", value_t=torch.tensor(1))) +@parse_args('v', 'i', 'i', 'i', 'i') def slice(g, self, dim, start, end, step): if step != 1: _unimplemented("slice", "step!=1 is currently not supported") return g.op("Slice", self, axes_i=[dim], starts_i=[start], ends_i=[end]) +@parse_args('v', 'f', 'f') def hardtanh(g, self, min_val, max_val): return g.op("Clip", self, min_f=min_val, max_f=max_val) @@ -772,11 +856,13 @@ def alias(g, self): return self +@parse_args('v', 'i') def unsqueeze(g, self, dim): return g.op("Unsqueeze", self, axes_i=[dim]) -def topk(g, self, k, dim=None, largest=True, sorted=True, out=None): +@parse_args('v', 'i', 'i', 'i', 'i') +def topk(g, self, k, dim, largest, sorted, out=None): if out is not None: _unimplemented("TopK", "Out parameter is not supported for topk") if not largest: @@ -785,6 +871,7 @@ def topk(g, self, k, dim=None, largest=True, sorted=True, out=None): return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2) +@parse_args('v', 'is') def repeat(g, self, repeats): if self.isTensor(): sizes = self.type().sizes() @@ -1041,5 +1128,6 @@ def retrieve_state(x, start, end): return symbolic +@parse_args('v', 'i') def _dim_arange(g, like, dim): return g.op('ATen', like, dim_i=dim, operator_s='_dim_arange') diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 59d567f461789b..7ce8220ff72b45 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -94,10 +94,23 @@ def export(model, args, f, export_params=True, verbose=False, training=False, operator_export_type=operator_export_type) -def _optimize_graph(graph, operator_export_type): +def _list_constant_prop(g, block): + for node in block.nodes(): + for subblock in node.blocks(): + _list_constant_prop(g, subblock) + if node.kind() == "prim::ListConstruct": + input_nodes = [i.node() for i in node.inputs()] + if all(inode.kind() == "prim::Constant" and inode.kindOf("value") == "i" for inode in input_nodes): + input_values = [inode['value'] for inode in input_nodes] + const_node = g.create("prim::Constant") + const_node.insertBefore(node) + const_node.is_("value", input_values) + const_node.output().setType(torch._C.ListType.ofInts()) + node.output().replaceAllUsesWith(const_node.output()) - # onnx only supports tensors, so we turn all out number types into tensors - torch._C._jit_pass_erase_number_types(graph) + +def _optimize_graph(graph, operator_export_type): + _list_constant_prop(graph, graph) # run dce to eliminate dead parts of the graph that might have been # left behind by things like symbolic_override @@ -106,6 +119,11 @@ def _optimize_graph(graph, operator_export_type): torch._C._jit_pass_peephole(graph) torch._C._jit_pass_lint(graph) + + # onnx only supports tensors, so we turn all out number types into tensors + torch._C._jit_pass_erase_number_types(graph) + torch._C._jit_pass_peephole(graph) + if operator_export_type != OperatorExportTypes.RAW: graph = torch._C._jit_pass_onnx(graph, operator_export_type) torch._C._jit_pass_lint(graph) @@ -452,7 +470,14 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor elif ns == "prim": if op_name == "Constant": - return g.op("Constant", value_t=n["value"]) + if n.kindOf("value") == "t": + return g.op("Constant", value_t=n["value"]) + elif n.kindOf("value") == "is": + value = torch.stack([torch.tensor(v) for v in n["value"]]) if n["value"] else [] + return g.op("Constant", value_t=value) + else: + raise 
RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
+                    n.kindOf("value")))
         elif op_name == "ListConstruct":
             unsqueezed = [g.op("Unsqueeze", input, axes_i=[0]) for input in inputs]
             return g.op("Concat", *unsqueezed, axis_i=0)
diff --git a/torch/tensor.py b/torch/tensor.py
index 60a50b6b67b454..67e195466e8781 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -242,7 +242,7 @@ def btrifact(self, info=None, pivot=True):
                           "consider using btrifact_with_info instead", stacklevel=2)
             factorization, pivots, _info = super(Tensor, self).btrifact_with_info(pivot=pivot)
             if info.type() != _info.type():
-                raise ValueError('btrifact expects info to be an IntTenor')
+                raise ValueError('btrifact expects info to be an IntTensor')
             info.resize_as_(_info).copy_(_info)
             return factorization, pivots
         else:
@@ -292,14 +292,6 @@ def scatter(self, dim, index, source):
     def scatter_add(self, dim, index, source):
         return self.clone().scatter_add_(dim, index, source)
 
-    def masked_copy(self, mask, tensor):
-        warnings.warn("masked_copy is deprecated and renamed to masked_scatter, and will be removed in v0.3")
-        return self.masked_scatter(mask, tensor)
-
-    def masked_copy_(self, mask, tensor):
-        warnings.warn("masked_copy_ is deprecated and renamed to masked_scatter_, and will be removed in v0.3")
-        return self.masked_scatter_(mask, tensor)
-
     def masked_scatter(self, mask, tensor):
         return self.clone().masked_scatter_(mask, tensor)
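
Note on the parse_args/_parse_arg machinery introduced in torch/onnx/symbolic.py above: symbolic functions now receive graph Values, and the decorator unpacks onnx::Constant nodes according to per-argument descriptors ('v' = raw value, 'i' = int, 'f' = float, 't' = tensor, 'is' = int list). A simplified, standalone sketch of that descriptor convention, using plain Python values in place of graph constants; parse_args_sketch and transpose_sketch are made-up names for illustration only:

from functools import wraps

def parse_args_sketch(*arg_descriptors):
    def decorator(fn):
        @wraps(fn)
        def wrapper(g, *args):
            assert len(arg_descriptors) == len(args)
            def convert(value, desc):
                if desc == 'v':          # pass graph values through untouched
                    return value
                if desc == 'i':
                    return int(value)
                if desc == 'f':
                    return float(value)
                if desc == 'is':
                    return [int(v) for v in value]
                raise RuntimeError("unsupported descriptor: " + desc)
            return fn(g, *[convert(a, d) for a, d in zip(args, arg_descriptors)])
        return wrapper
    return decorator

@parse_args_sketch('v', 'i', 'i')
def transpose_sketch(g, self, dim0, dim1):
    # mirrors the shape of the decorated symbolics, e.g. transpose(g, self, dim0, dim1)
    return ("Transpose", self, dim0, dim1)

print(transpose_sketch(None, "input_value", 0.0, 1.0))  # dim arguments are coerced to ints
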
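
The stack symbolic above (both the removed and rewritten versions) lowers torch.stack to an Unsqueeze of each input followed by a Concat along the same dimension. The identity it relies on can be checked directly in eager mode; this is an illustration, not part of the patch:

import torch

tensors = [torch.randn(2, 3) for _ in range(4)]
dim = 1
stacked = torch.stack(tensors, dim)
concat_of_unsqueezed = torch.cat([t.unsqueeze(dim) for t in tensors], dim)
# identical results, which is why the ONNX export can emit Unsqueeze + Concat
assert torch.equal(stacked, concat_of_unsqueezed)
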