diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index b1dd9d348e8c0b..bc2762860dd2bd 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -9,6 +9,11 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" echo "Testing pytorch" +if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then + echo "Skipping ROCm tests for now" + exit 0 +fi + # JIT C++ extensions require ninja. git clone https://github.com/ninja-build/ninja --quiet pushd ninja diff --git a/README.md b/README.md index 27872820ccde95..35e14867cbba64 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -

+![PyTorch Logo](https://github.com/pytorch/pytorch/blob/master/docs/source/_static/img/pytorch-logo-dark.png) -------------------------------------------------------------------------------- @@ -34,32 +34,14 @@ See also the [ci.pytorch.org HUD](https://ezyang.github.io/pytorch-ci-hud/build/ At a granular level, PyTorch is a library that consists of the following components:
-<table>
-<tr>
-    <td><b> torch </b></td>
-    <td> a Tensor library like NumPy, with strong GPU support </td>
-</tr>
-<tr>
-    <td><b> torch.autograd </b></td>
-    <td> a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch </td>
-</tr>
-<tr>
-    <td><b> torch.nn </b></td>
-    <td> a neural networks library deeply integrated with autograd designed for maximum flexibility </td>
-</tr>
-<tr>
-    <td><b> torch.multiprocessing </b></td>
-    <td> Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training. </td>
-</tr>
-<tr>
-    <td><b> torch.utils </b></td>
-    <td> DataLoader, Trainer and other utility functions for convenience </td>
-</tr>
-<tr>
-    <td><b> torch.legacy(.nn/.optim) </b></td>
-    <td> legacy code that has been ported over from torch for backward compatibility reasons </td>
-</tr>
-</table>
+| Component | Description | +| ---- | --- | +| **torch** | a Tensor library like NumPy, with strong GPU support | +| **torch.autograd** | a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch | +| **torch.nn** | a neural networks library deeply integrated with autograd designed for maximum flexibility | +| **torch.multiprocessing** | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training | +| **torch.utils** | DataLoader, Trainer and other utility functions for convenience | +| **torch.legacy(.nn/.optim)** | legacy code that has been ported over from torch for backward compatibility reasons | Usually one uses PyTorch either as: @@ -72,7 +54,7 @@ Elaborating further: If you use NumPy, then you have used Tensors (a.k.a ndarray). -

+![Tensor illustration](https://github.com/pytorch/pytorch/blob/master/docs/source/_static/img/tensor_illustration.png) PyTorch provides Tensors that can live either on the CPU or the GPU, and accelerate compute by a huge amount. @@ -99,7 +81,7 @@ from several research papers on this topic, as well as current and past work suc While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. You get the best of speed and flexibility for your crazy research. -

+![Dynamic graph](https://github.com/pytorch/pytorch/blob/master/docs/source/_static/img/dynamic_graph.gif) ### Python First diff --git a/aten/src/ATen/Retainable.h b/aten/src/ATen/Retainable.h index da0bdda4bf6865..792a2209afa09d 100644 --- a/aten/src/ATen/Retainable.h +++ b/aten/src/ATen/Retainable.h @@ -7,21 +7,52 @@ namespace at { // base class for refcounted things, allows for collects of generic // refcounted objects that include tensors struct Retainable { - Retainable(): refcount(1) {} + Retainable(): refcount(1), weak_refcount(1) {} void retain() { ++refcount; } void release() { if(--refcount == 0) { + // If we know that this is the last reference then we can skip + // all the decrements and release_resources(). + if (weak_refcount == 1) { + delete this; + } else { + release_resources(); + weak_release(); + } + } + } + void weak_retain() { + ++weak_refcount; + } + void weak_release() { + if (--weak_refcount == 0) { delete this; } } - int use_count() const { + bool weak_lock() { + for (;;) { + auto current_refcount = refcount.load(); + if (current_refcount == 0) return false; + if (refcount.compare_exchange_strong(current_refcount, current_refcount + 1)) break; + } + return true; + } + uint32_t use_count() const { return refcount.load(); } + uint32_t weak_use_count() const { + return weak_refcount.load(); + } + + virtual void release_resources() {}; virtual ~Retainable() {} private: - std::atomic refcount; + // INVARIANT: once refcount reaches 0 it can never go up + // INVARIANT: weak_refcount = number of weak references + (refcount > 0 ? 1 : 0) + std::atomic refcount; + std::atomic weak_refcount; }; } diff --git a/aten/src/ATen/TensorBase.h b/aten/src/ATen/TensorBase.h index 6f70e3df7391f5..3aea68ffdfdbcb 100644 --- a/aten/src/ATen/TensorBase.h +++ b/aten/src/ATen/TensorBase.h @@ -5,54 +5,62 @@ namespace at { namespace detail { -// TensorBase is the base class for Tensor which handles the reference counting -struct TensorBase { - TensorBase(): TensorBase(UndefinedTensor::singleton(), false) {} - TensorBase(TensorImpl * self, bool retain) +// TensorBaseImpl is the base class for Tensor which handles the reference counting +template +struct TensorBaseImpl { + TensorBaseImpl(): TensorBaseImpl(UndefinedTensor::singleton(), false) {} + TensorBaseImpl(TensorImpl * self, bool should_retain) : pImpl(self) { if (pImpl == nullptr) { - throw std::runtime_error("TensorBase with nullptr not supported"); + throw std::runtime_error("TensorBaseImpl with nullptr not supported"); + } + if(should_retain && pImpl != UndefinedTensor::singleton()) { + retain(); } - if(retain && pImpl != UndefinedTensor::singleton()) - pImpl->retain(); } - TensorBase(const TensorBase & rhs) + TensorBaseImpl(const TensorBaseImpl & rhs) : pImpl(rhs.pImpl) { - if (pImpl != UndefinedTensor::singleton()) - pImpl->retain(); + if (pImpl != UndefinedTensor::singleton()) { + retain(); + } } - TensorBase(TensorBase && rhs) noexcept + TensorBaseImpl(TensorBaseImpl && rhs) noexcept : pImpl(rhs.pImpl) { rhs.pImpl = UndefinedTensor::singleton(); } - ~TensorBase() { - if (pImpl != UndefinedTensor::singleton()) - pImpl->release(); + ~TensorBaseImpl() { + if (pImpl != UndefinedTensor::singleton()) { + release(); + } } - TensorBase & operator=(TensorBase && rhs) & { + TensorBaseImpl & operator=(TensorBaseImpl && rhs) & { rhs.swap(*this); return *this; } - TensorBase & operator=(TensorBase const & rhs) & { - //TensorBase ctor retains original rhs.pImpl - //then rhs.pImpl is swapped with this->pImpl - //finally TensorBase dtor releases 
rhs.pImpl, which was originally this->pImpl - TensorBase(rhs).swap(*this); - return *this; + TensorBaseImpl & operator=(TensorBaseImpl const & rhs) & { + //TensorBaseImpl ctor retains original rhs.pImpl + //then rhs.pImpl is swapped with this->pImpl + //finally TensorBaseImpl dtor releases rhs.pImpl, which was originally this->pImpl + TensorBaseImpl(rhs).swap(*this); + return *this; } int64_t dim() const { - return pImpl->dim(); + if (is_strong) { + return pImpl->dim(); + } else { + AT_ERROR("Can't call dim() on a WeakTensor"); + } } void reset() { - TensorBase().swap(*this); + TensorBaseImpl().swap(*this); } void reset(TensorImpl * rhs) { - TensorBase(rhs, true).swap(*this); + TensorBaseImpl(rhs, true).swap(*this); } - void reset(TensorImpl * rhs, bool retain) { - TensorBase(rhs, retain).swap(*this ); + void reset(TensorImpl * rhs, bool should_retain) { + TensorBaseImpl(rhs, should_retain).swap(*this ); } - void swap(TensorBase & rhs) { + void swap(TensorBaseImpl & rhs) { TensorImpl * tmp = pImpl; pImpl = rhs.pImpl; rhs.pImpl = tmp; @@ -75,6 +83,26 @@ struct TensorBase { //TODO(zach): sort out friend structes public: TensorImpl * pImpl; + +private: + void retain() { + if (is_strong) { + pImpl->retain(); + } else { + pImpl->weak_retain(); + } + } + + void release() { + if (is_strong) { + pImpl->release(); + } else { + pImpl->weak_release(); + } + } }; +using TensorBase = TensorBaseImpl; +using WeakTensorBase = TensorBaseImpl; + }} // namespace at::detail diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index 18b562130ce9d4..ccefa25497ab4b 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -12,11 +12,9 @@ #if defined(__clang__) #define __ubsan_ignore_float_divide_by_zero__ __attribute__((no_sanitize("float-divide-by-zero"))) -#define __ubsan_ignore_function__ __attribute__((no_sanitize("function"))) #define __ubsan_ignore_vptr__ __attribute__((no_sanitize("vptr"))) #else #define __ubsan_ignore_float_divide_by_zero__ -#define __ubsan_ignore_function__ #define __ubsan_ignore_vptr__ #endif diff --git a/aten/src/ATen/templates/Tensor.h b/aten/src/ATen/templates/Tensor.h index 884f583f768195..31e952ebb79ff8 100644 --- a/aten/src/ATen/templates/Tensor.h +++ b/aten/src/ATen/templates/Tensor.h @@ -13,6 +13,7 @@ #include "ATen/Utils.h" #include "ATen/Device.h" #include "ATen/Layout.h" +#include "ATen/optional.h" namespace at { struct Type; @@ -42,6 +43,7 @@ namespace at { // Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and // special care must be taken to handle this. struct Tensor : public detail::TensorBase { + using TensorBase = detail::TensorBase; Tensor() : TensorBase() {} Tensor(TensorImpl * self, bool retain) : TensorBase(self, retain) {} Tensor(const TensorBase & rhs) : TensorBase(rhs) {} @@ -198,6 +200,46 @@ struct Tensor : public detail::TensorBase { auto m(F func, Args&&... 
params) const -> decltype(func(*this, std::forward(params)...)) { return func(*this, std::forward(params)...); } + + friend struct WeakTensor; +}; + +struct WeakTensor : public detail::WeakTensorBase { + using WeakTensorBase = detail::WeakTensorBase; + WeakTensor() : WeakTensorBase() {} + WeakTensor(TensorImpl * self, bool retain) : WeakTensorBase(self, retain) {} + WeakTensor(const WeakTensor & rhs) = default; + WeakTensor(WeakTensor && rhs) noexcept = default; + WeakTensor(const Tensor& t) : WeakTensorBase(t.pImpl, true) {} + + // reimplemented from TensorBase so the return type is WeakTensor rather than TensorBase + WeakTensor & operator=(WeakTensor && rhs) & { + rhs.swap(*this); + return *this; + } + WeakTensor & operator=(WeakTensor const & rhs) & { + //Tensor ctor retains original rhs.pImpl + //then rhs.pImpl is swapped with this->pImpl + //finally Tensor dtor releases rhs.pImpl, which was originally this->pImpl + WeakTensor(rhs).swap(*this); + return *this; + } + + WeakTensor & operator=(const Tensor& t) { + WeakTensor(t.pImpl, true).swap(*this); + return *this; + } + + // non-retaining + TensorImpl * unsafeGetTensorImpl() const { + return pImpl; + } + + // XXX: this can return undefined tensors + // Ideally it would be at::optional, but MSVC is too cool for that + Tensor lock() const { + return pImpl->weak_lock() ? Tensor(pImpl, false) : Tensor(); + } }; namespace detail { diff --git a/aten/src/ATen/templates/TensorDerived.cpp b/aten/src/ATen/templates/TensorDerived.cpp index 1560e65863d7b7..e15eb5fcb07dda 100644 --- a/aten/src/ATen/templates/TensorDerived.cpp +++ b/aten/src/ATen/templates/TensorDerived.cpp @@ -49,6 +49,11 @@ void * ${Tensor}::unsafeGetTH(bool retain) { return tensor; } +void ${Tensor}::release_resources() { + ${THTensor}_free(${state,} tensor); + tensor = nullptr; +} + ${TensorDenseOrSparse} } diff --git a/aten/src/ATen/templates/TensorDerived.h b/aten/src/ATen/templates/TensorDerived.h index 092d0634ca24c8..892d6bcca58276 100644 --- a/aten/src/ATen/templates/TensorDerived.h +++ b/aten/src/ATen/templates/TensorDerived.h @@ -23,6 +23,7 @@ struct ${Tensor} final : public TensorImpl { virtual Scalar localScalar() override; virtual void * unsafeGetTH(bool retain) override; virtual std::unique_ptr storage() override; + virtual void release_resources() override; static const char * typeString(); //TODO(zach): sort of friend permissions later so this diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index db2bf723050a8d..25d84a36af4af9 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -18,7 +18,8 @@ list(APPEND ATen_CPU_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/test_parallel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp) list(APPEND ATen_CUDA_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/integer_divider_test.cu diff --git a/aten/src/ATen/test/weakref_test.cpp b/aten/src/ATen/test/weakref_test.cpp new file mode 100644 index 00000000000000..aab2ec57d92b95 --- /dev/null +++ b/aten/src/ATen/test/weakref_test.cpp @@ -0,0 +1,64 @@ +#define CATCH_CONFIG_MAIN +#include "catch.hpp" + +#include "ATen/ATen.h" + +#include +#include +#include + +using at::Tensor; +using at::WeakTensor; + +TEST_CASE( "Weak pointer tests", "" ) { + SECTION("gets invalidated") { + Tensor a = at::ones({2, 2}); + WeakTensor b = a; + 
a.reset(); + REQUIRE_FALSE(b.lock().defined()); + } + + SECTION("can successfully lock") { + Tensor a = at::ones({2, 2}); + WeakTensor b = a; + auto c = b.lock(); + REQUIRE(c.defined()); + + a.reset(); + REQUIRE(b.lock().defined()); + c.reset(); + REQUIRE_FALSE(b.lock().defined()); + } + + SECTION("updates refcounts correctly") { + Tensor a = at::ones({2, 2}); + auto ai = a.unsafeGetTensorImpl(); + REQUIRE(ai->use_count() == 1); + REQUIRE(ai->weak_use_count() == 1); + { + WeakTensor b = a; + REQUIRE(ai->use_count() == 1); + REQUIRE(ai->weak_use_count() == 2); + } + REQUIRE(ai->use_count() == 1); + REQUIRE(ai->weak_use_count() == 1); + { + WeakTensor b = a; + REQUIRE(ai->use_count() == 1); + auto locked = b.lock(); + REQUIRE(locked.defined()); + REQUIRE(ai->use_count() == 2); + } + REQUIRE(ai->use_count() == 1); + REQUIRE(ai->weak_use_count() == 1); + { + WeakTensor b = a; + REQUIRE(ai->use_count() == 1); + REQUIRE(ai->weak_use_count() == 2); + a.reset(); + auto bi = b.unsafeGetTensorImpl(); + REQUIRE(bi->use_count() == 0); + REQUIRE(bi->weak_use_count() == 1); + } + } +} diff --git a/aten/src/TH/THStorage.cpp b/aten/src/TH/THStorage.cpp index 52fc5dd471a2eb..f4910c3f07fe32 100644 --- a/aten/src/TH/THStorage.cpp +++ b/aten/src/TH/THStorage.cpp @@ -27,9 +27,6 @@ void THStorage_free(THStorage *storage) { } storage->finalizer.~unique_ptr(); storage->data_ptr.~DataPtr(); - if (storage->flag & TH_STORAGE_VIEW) { - THStorage_free(storage->view); - } THStorage_weakFree(storage); } } @@ -227,6 +224,5 @@ void THStorage_swap(THStorage *storage1, THStorage *storage2) SWAP(flag); SWAP(allocator); SWAP(finalizer); - SWAP(view); #undef SWAP } diff --git a/aten/src/TH/THStorage.hpp b/aten/src/TH/THStorage.hpp index 303f89094094d4..e02e265062d94b 100644 --- a/aten/src/TH/THStorage.hpp +++ b/aten/src/TH/THStorage.hpp @@ -47,7 +47,6 @@ typedef struct THStorage char flag; at::Allocator *allocator; std::unique_ptr finalizer; - struct THStorage *view; template inline T * data() const { diff --git a/aten/src/TH/generic/THStorage.h b/aten/src/TH/generic/THStorage.h index 7212a64dffb1f7..4850c4746136db 100644 --- a/aten/src/TH/generic/THStorage.h +++ b/aten/src/TH/generic/THStorage.h @@ -22,7 +22,6 @@ #define TH_STORAGE_REFCOUNTED 1 #define TH_STORAGE_RESIZABLE 2 -#define TH_STORAGE_VIEW 8 // Struct definition is moved to THStorage.hpp (so this file stays C compatible) typedef struct THStorage THStorage; diff --git a/aten/src/THC/THCStorage.cpp b/aten/src/THC/THCStorage.cpp index dd3d56c437700d..ab92022d9a6a1d 100644 --- a/aten/src/THC/THCStorage.cpp +++ b/aten/src/THC/THCStorage.cpp @@ -55,9 +55,6 @@ void THCStorage_free(THCState *state, THCStorage *storage) } storage->finalizer.~unique_ptr(); storage->data_ptr.~DataPtr(); - if (storage->flag & TH_STORAGE_VIEW) { - THCStorage_free(state, storage->view); - } THStorage_weakFree(storage); } } diff --git a/caffe2/contrib/tensorrt/tensorrt_tranformer.cc b/caffe2/contrib/tensorrt/tensorrt_tranformer.cc index a1fbdeb258566d..2c0522438feefd 100644 --- a/caffe2/contrib/tensorrt/tensorrt_tranformer.cc +++ b/caffe2/contrib/tensorrt/tensorrt_tranformer.cc @@ -479,7 +479,7 @@ void TensorRTTransformer::Transform( auto trt_builder = tensorrt::TrtObject(nvinfer1::createInferBuilder(logger)); auto trt_network = tensorrt::TrtObject(trt_builder->createNetwork()); auto importer = - tensorrt::TrtObject(nvonnxparser::createParser(*trt_network, logger)); + tensorrt::TrtObject(nvonnxparser::createParser(trt_network.get(), logger)); // function to tell whether TensorRT supports a given 
C2 op or not auto supports = diff --git a/caffe2/contrib/tensorrt/trt_utils.cc b/caffe2/contrib/tensorrt/trt_utils.cc index 2c3e6e99282be9..f1efa4e3c57077 100644 --- a/caffe2/contrib/tensorrt/trt_utils.cc +++ b/caffe2/contrib/tensorrt/trt_utils.cc @@ -13,7 +13,7 @@ std::shared_ptr BuildTrtEngine( auto trt_builder = TrtObject(nvinfer1::createInferBuilder(*logger)); auto trt_network = TrtObject(trt_builder->createNetwork()); auto trt_parser = - TrtObject(nvonnxparser::createParser(*trt_network, *logger)); + TrtObject(nvonnxparser::createParser(trt_network.get(), *logger)); auto status = trt_parser->parse(onnx_model_str.data(), onnx_model_str.size()); if (!status) { const auto num_errors = trt_parser->getNbErrors(); diff --git a/caffe2/operators/bbox_transform_op.cc b/caffe2/operators/bbox_transform_op.cc index 5dde4b121dab53..0d2b5a3a9aa25a 100644 --- a/caffe2/operators/bbox_transform_op.cc +++ b/caffe2/operators/bbox_transform_op.cc @@ -54,6 +54,11 @@ Transform proposal bounding boxes to target bounding box using bounding box "angle_bound_hi", "int (default 90 degrees). If set, for rotated boxes, angle is " "normalized to be within [angle_bound_lo, angle_bound_hi].") + .Arg( + "clip_angle_thresh", + "float (default 1.0 degrees). For RRPN, clip almost horizontal boxes " + "within this threshold of tolerance for backward compatibility. " + "Set to negative value for no clipping.") .Input( 0, "rois", @@ -168,7 +173,8 @@ bool BBoxTransformOp::RunOnDevice() { angle_bound_on_, angle_bound_lo_, angle_bound_hi_); - EArrXXf clip_boxes = utils::clip_boxes(trans_boxes, img_h, img_w); + EArrXXf clip_boxes = + utils::clip_boxes(trans_boxes, img_h, img_w, clip_angle_thresh_); // Do not apply scale for angle in rotated boxes clip_boxes.leftCols(4) *= scale_after; new_boxes.block(offset, k * box_dim, num_rois, box_dim) = clip_boxes; diff --git a/caffe2/operators/bbox_transform_op.h b/caffe2/operators/bbox_transform_op.h index e57d90e0266cf3..8d76973576ccf8 100644 --- a/caffe2/operators/bbox_transform_op.h +++ b/caffe2/operators/bbox_transform_op.h @@ -29,7 +29,9 @@ class BBoxTransformOp final : public Operator { angle_bound_lo_( OperatorBase::GetSingleArgument("angle_bound_lo", -90)), angle_bound_hi_( - OperatorBase::GetSingleArgument("angle_bound_hi", 90)) { + OperatorBase::GetSingleArgument("angle_bound_hi", 90)), + clip_angle_thresh_( + OperatorBase::GetSingleArgument("clip_angle_thresh", 1.0)) { CAFFE_ENFORCE_EQ( weights_.size(), 4, @@ -59,6 +61,10 @@ class BBoxTransformOp final : public Operator { bool angle_bound_on_{true}; int angle_bound_lo_{-90}; int angle_bound_hi_{90}; + // For RRPN, clip almost horizontal boxes within this threshold of + // tolerance for backward compatibility. Set to negative value for + // no clipping. 
+ float clip_angle_thresh_{1.0}; }; } // namespace caffe2 diff --git a/caffe2/operators/filler_op.cu b/caffe2/operators/filler_op.cu index 55db9a57756ed1..9df195a918b91a 100644 --- a/caffe2/operators/filler_op.cu +++ b/caffe2/operators/filler_op.cu @@ -1,6 +1,7 @@ #include #include "caffe2/core/context_gpu.h" #include "caffe2/operators/filler_op.h" +#include "caffe2/operators/operator_fallback_gpu.h" namespace caffe2 { @@ -63,5 +64,8 @@ REGISTER_CUDA_OPERATOR(GaussianFill, GaussianFillOp); REGISTER_CUDA_OPERATOR(XavierFill, XavierFillOp); REGISTER_CUDA_OPERATOR(MSRAFill, MSRAFillOp); REGISTER_CUDA_OPERATOR(RangeFill, RangeFillOp); +REGISTER_CUDA_OPERATOR( + LengthsRangeFill, + GPUFallbackOp>); } // namespace caffe2 diff --git a/caffe2/operators/filler_op_gpu.cc b/caffe2/operators/filler_op_gpu.cc deleted file mode 100644 index b34236018afb2c..00000000000000 --- a/caffe2/operators/filler_op_gpu.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "caffe2/core/context_gpu.h" -#include "caffe2/operators/filler_op.h" -#include "caffe2/operators/operator_fallback_gpu.h" - -namespace caffe2 { -REGISTER_CUDA_OPERATOR( - LengthsRangeFill, - GPUFallbackOp>); -} diff --git a/caffe2/operators/generate_proposals_op.cc b/caffe2/operators/generate_proposals_op.cc index dff52aa2ac2e22..0b4f3a6a9d7553 100644 --- a/caffe2/operators/generate_proposals_op.cc +++ b/caffe2/operators/generate_proposals_op.cc @@ -197,8 +197,8 @@ void GenerateProposalsOp::ProposalsForOneImage( // 2. clip proposals to image (may result in proposals with zero area // that will be removed in the next step) - // TODO (viswanath): Should we clip rotated boxes as well? - proposals = utils::clip_boxes(proposals, im_info[0], im_info[1]); + proposals = + utils::clip_boxes(proposals, im_info[0], im_info[1], clip_angle_thresh_); // 3. remove predicted boxes with either height or width < min_size auto keep = utils::filter_boxes(proposals, min_size, im_info); @@ -342,6 +342,29 @@ non-maximum suppression is applied to generate the final bounding boxes. .Arg("post_nms_topN", "(int) RPN_POST_NMS_TOP_N") .Arg("nms_thresh", "(float) RPN_NMS_THRESH") .Arg("min_size", "(float) RPN_MIN_SIZE") + .Arg( + "correct_transform_coords", + "bool (default false), Correct bounding box transform coordates," + " see bbox_transform() in boxes.py " + "Set to true to match the detectron code, set to false for backward" + " compatibility") + .Arg( + "angle_bound_on", + "bool (default true). If set, for rotated boxes, angle is " + "normalized to be within [angle_bound_lo, angle_bound_hi].") + .Arg( + "angle_bound_lo", + "int (default -90 degrees). If set, for rotated boxes, angle is " + "normalized to be within [angle_bound_lo, angle_bound_hi].") + .Arg( + "angle_bound_hi", + "int (default 90 degrees). If set, for rotated boxes, angle is " + "normalized to be within [angle_bound_lo, angle_bound_hi].") + .Arg( + "clip_angle_thresh", + "float (default 1.0 degrees). For RRPN, clip almost horizontal boxes " + "within this threshold of tolerance for backward compatibility. 
" + "Set to negative value for no clipping.") .Input(0, "scores", "Scores from conv layer, size (img_count, A, H, W)") .Input( 1, diff --git a/caffe2/operators/generate_proposals_op.h b/caffe2/operators/generate_proposals_op.h index c1ae4889e8931a..81f7d9ac43123f 100644 --- a/caffe2/operators/generate_proposals_op.h +++ b/caffe2/operators/generate_proposals_op.h @@ -84,7 +84,9 @@ class GenerateProposalsOp final : public Operator { angle_bound_lo_( OperatorBase::GetSingleArgument("angle_bound_lo", -90)), angle_bound_hi_( - OperatorBase::GetSingleArgument("angle_bound_hi", 90)) {} + OperatorBase::GetSingleArgument("angle_bound_hi", 90)), + clip_angle_thresh_( + OperatorBase::GetSingleArgument("clip_angle_thresh", 1.0)) {} ~GenerateProposalsOp() {} @@ -127,6 +129,10 @@ class GenerateProposalsOp final : public Operator { bool angle_bound_on_{true}; int angle_bound_lo_{-90}; int angle_bound_hi_{90}; + // For RRPN, clip almost horizontal boxes within this threshold of + // tolerance for backward compatibility. Set to negative value for + // no clipping. + float clip_angle_thresh_{1.0}; }; } // namespace caffe2 diff --git a/caffe2/operators/generate_proposals_op_test.cc b/caffe2/operators/generate_proposals_op_test.cc index d8e1021010aa32..3fb7ed92e90a6a 100644 --- a/caffe2/operators/generate_proposals_op_test.cc +++ b/caffe2/operators/generate_proposals_op_test.cc @@ -320,6 +320,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) { // Similar to TestRealDownSampled but for rotated boxes with angle info. float angle = 0; float delta_angle = 0; + float clip_angle_thresh = 1.0; Workspace ws; OperatorDef def; @@ -407,33 +408,37 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) { vector im_info{60, 80, 0.166667f}; // vector anchors{-38, -16, 53, 31, -120, -120, 135, 135}; - vector anchors{8, 8, 92, 48, angle, 8, 8, 256, 256, angle}; - - // Although angle == 0, the results aren't exactly the same as - // TestRealDownSampled because because clip_boxes() is not performed - // for RRPN style boxes. - ERMatXf rois_gt(13, 6); - rois_gt << 0, 6.55346, 25.3227, 253.447, 291.446, 0, 0, 55.3932, 33.3369, - 253.731, 289.158, 0, 0, 6.48163, 24.3478, 92.3015, 38.6944, 0, 0, 70.3089, - 26.7894, 92.3453, 38.5539, 0, 0, 22.3067, 26.7714, 92.3424, 38.5243, 0, 0, - 054.084, 26.8413, 92.3938, 38.798, 0, 0, 5.33962, 42.2077, 92.5497, - 38.2259, 0, 0, 6.36709, 58.24, 92.16, 37.4372, 0, 0, 69.65, 48.6713, - 92.1521, 37.3668, 0, 0, 20.4147, 44.4783, 91.7111, 34.0295, 0, 0, 033.079, - 41.5149, 92.3244, 36.4278, 0, 0, 41.8235, 037.291, 90.2815, 034.872, 0, 0, - 13.8486, 48.662, 88.7818, 28.875, 0; - vector rois_probs_gt{0.0266914, - 0.005621, - 0.00544219, - 0.00120544, - 0.00119208, - 0.00117182, - 0.000617993, - 0.000472735, - 6.09605e-05, - 1.05262e-05, - 8.91026e-06, - 9.29537e-09, - 1.13482e-10}; + // Anchors in [x_ctr, y_ctr, w, h, angle] format + vector anchors{7.5, 7.5, 92, 48, angle, 7.5, 7.5, 256, 256, angle}; + + // Results should exactly be the same as TestRealDownSampled since + // angle = 0 for all boxes and clip_angle_thresh > 0 (which means + // all horizontal boxes will be clipped to maintain backward compatibility). 
+ ERMatXf rois_gt_xyxy(9, 5); + rois_gt_xyxy << 0, 0, 0, 79, 59, 0, 0, 5.0005703f, 51.6324f, 42.6950f, 0, + 24.13628387f, 7.51243401f, 79, 45.0663f, 0, 0, 7.50924301f, 67.4779f, + 45.0336, 0, 0, 23.09477997f, 50.61448669f, 59, 0, 0, 39.52141571f, + 51.44710541f, 59, 0, 23.57396317f, 29.98791885f, 79, 59, 0, 0, + 41.90219116f, 79, 59, 0, 0, 23.30098343f, 78.2413f, 58.7287f; + ERMatXf rois_gt(9, 6); + // Batch ID + rois_gt.block(0, 0, rois_gt.rows(), 1) = + ERMatXf::Constant(rois_gt.rows(), 1, 0.0); + // rois_gt in [x_ctr, y_ctr, w, h] format + rois_gt.block(0, 1, rois_gt.rows(), 4) = + boxes_xyxy_to_xywh(rois_gt_xyxy.block(0, 1, rois_gt.rows(), 4)); + // Angle + rois_gt.block(0, 5, rois_gt.rows(), 1) = + ERMatXf::Constant(rois_gt.rows(), 1, angle); + vector rois_probs_gt{2.66913995e-02f, + 5.44218998e-03f, + 1.20544003e-03f, + 1.19207997e-03f, + 6.17993006e-04f, + 4.72735002e-04f, + 6.09605013e-05f, + 1.50015003e-05f, + 8.91025957e-06f}; AddInput(vector{img_count, A, H, W}, scores, "scores", &ws); AddInput( @@ -450,6 +455,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) { def.add_arg()->CopyFrom(MakeArgument("nms_thresh", 0.7f)); def.add_arg()->CopyFrom(MakeArgument("min_size", 16.0f)); def.add_arg()->CopyFrom(MakeArgument("correct_transform_coords", true)); + def.add_arg()->CopyFrom(MakeArgument("clip_angle_thresh", clip_angle_thresh)); unique_ptr op(CreateOperator(def, &ws)); EXPECT_NE(nullptr, op.get()); @@ -484,6 +490,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotated) { float angle = 45.0; float delta_angle = 0.174533; // 0.174533 radians -> 10 degrees float expected_angle = 55.0; + float clip_angle_thresh = 1.0; Workspace ws; OperatorDef def; @@ -588,6 +595,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotated) { def.add_arg()->CopyFrom(MakeArgument("nms_thresh", 0.7f)); def.add_arg()->CopyFrom(MakeArgument("min_size", 16.0f)); def.add_arg()->CopyFrom(MakeArgument("correct_transform_coords", true)); + def.add_arg()->CopyFrom(MakeArgument("clip_angle_thresh", clip_angle_thresh)); unique_ptr op(CreateOperator(def, &ws)); EXPECT_NE(nullptr, op.get()); diff --git a/caffe2/operators/generate_proposals_op_util_boxes.h b/caffe2/operators/generate_proposals_op_util_boxes.h index 440d141899e2f6..0c4c345d382cb1 100644 --- a/caffe2/operators/generate_proposals_op_util_boxes.h +++ b/caffe2/operators/generate_proposals_op_util_boxes.h @@ -192,23 +192,50 @@ EArrXXt bbox_transform( } } +template +EArrXXt bbox_xyxy_to_ctrwh( + const Eigen::ArrayBase& boxes) { + CAFFE_ENFORCE_EQ(boxes.cols(), 4); + + const auto& x1 = boxes.col(0); + const auto& y1 = boxes.col(1); + const auto& x2 = boxes.col(2); + const auto& y2 = boxes.col(3); + + EArrXXt ret(boxes.rows(), 4); + ret.col(0) = (x1 + x2) / 2.0; // x_ctr + ret.col(1) = (y1 + y2) / 2.0; // y_ctr + ret.col(2) = x2 - x1 + 1.0; // w + ret.col(3) = y2 - y1 + 1.0; // h + return ret; +} + +template +EArrXXt bbox_ctrwh_to_xyxy( + const Eigen::ArrayBase& boxes) { + CAFFE_ENFORCE_EQ(boxes.cols(), 4); + + const auto& x_ctr = boxes.col(0); + const auto& y_ctr = boxes.col(1); + const auto& w = boxes.col(2); + const auto& h = boxes.col(3); + + EArrXXt ret(boxes.rows(), 4); + ret.col(0) = x_ctr - (w - 1) / 2.0; // x1 + ret.col(1) = y_ctr - (h - 1) / 2.0; // y1 + ret.col(2) = x_ctr + (w - 1) / 2.0; // x2 + ret.col(3) = y_ctr + (h - 1) / 2.0; // y2 + return ret; +} + // Clip boxes to image boundaries // boxes: pixel coordinates of bounding box, size (M * 4) -// -// For rotated boxes with angle support (M * 5), we don't clip and just -// 
return early. It's tricky to make the entire rectangular box fit within the -// image and still be able to not leave out pixels of interest. -// We rely on upstream ops like RoIAlignRotated safely handling such cases. template -EArrXXt -clip_boxes(const Eigen::ArrayBase& boxes, int height, int width) { - CAFFE_ENFORCE(boxes.cols() == 4 || boxes.cols() == 5); - if (boxes.cols() == 5) { - // No clipping for rotated boxes. - // TODO (viswanath): Should this be implemented for backward compatibility - // with angle=0 case? - return boxes; - } +EArrXXt clip_boxes_upright( + const Eigen::ArrayBase& boxes, + int height, + int width) { + CAFFE_ENFORCE(boxes.cols() == 4); EArrXXt ret(boxes.rows(), boxes.cols()); @@ -224,6 +251,69 @@ clip_boxes(const Eigen::ArrayBase& boxes, int height, int width) { return ret; } +// Similar to clip_boxes_upright but handles rotated boxes with angle info. +// boxes: size (M, 5), format [ctr_x; ctr_y; width; height; angle (in degrees)] +// +// Clipping is only performed for boxes that are almost upright +// (within a given `angle_thresh` tolerance) to maintain backward compatibility +// for non-rotated boxes. +// +// We don't clip rotated boxes due to a couple of reasons: +// (1) There are potentially multiple ways to clip a rotated box to make it +// fit within the image. +// (2) It's tricky to make the entire rectangular box fit within the image and +// still be able to not leave out pixels of interest. +// Therefore, we rely on upstream ops like RoIAlignRotated safely handling this. +template +EArrXXt clip_boxes_rotated( + const Eigen::ArrayBase& boxes, + int height, + int width, + float angle_thresh = 1.0) { + CAFFE_ENFORCE(boxes.cols() == 5); + + const auto& angles = boxes.col(4); + + // Filter boxes that are upright (with a tolerance of angle_thresh) + EArrXXt upright_boxes; + const auto& indices = GetArrayIndices(angles.abs() <= angle_thresh); + GetSubArrayRows(boxes, AsEArrXt(indices), &upright_boxes); + + // Convert to [x1, y1, x2, y2] format and clip them + const auto& upright_boxes_xyxy = + bbox_ctrwh_to_xyxy(upright_boxes.leftCols(4)); + const auto& clipped_upright_boxes_xyxy = + clip_boxes_upright(upright_boxes_xyxy, height, width); + + // Convert back to [x_ctr, y_ctr, w, h, angle] and update upright boxes + upright_boxes.block(0, 0, upright_boxes.rows(), 4) = + bbox_xyxy_to_ctrwh(clipped_upright_boxes_xyxy); + + EArrXXt ret(boxes.rows(), boxes.cols()); + ret = boxes; + for (int i = 0; i < upright_boxes.rows(); ++i) { + ret.row(indices[i]) = upright_boxes.row(i); + } + return ret; +} + +// Clip boxes to image boundaries. +template +EArrXXt clip_boxes( + const Eigen::ArrayBase& boxes, + int height, + int width, + float angle_thresh = 1.0) { + CAFFE_ENFORCE(boxes.cols() == 4 || boxes.cols() == 5); + if (boxes.cols() == 4) { + // Upright boxes + return clip_boxes_upright(boxes, height, width); + } else { + // Rotated boxes with angle info + return clip_boxes_rotated(boxes, height, width, angle_thresh); + } +} + // Only keep boxes with both sides >= min_size and center within the image. 
// boxes: pixel coordinates of bounding box, size (M * 4) // im_info: [height, width, img_scale] diff --git a/caffe2/operators/generate_proposals_op_util_boxes_test.cc b/caffe2/operators/generate_proposals_op_util_boxes_test.cc index a8d4f4c327e649..f9ff7e94bba274 100644 --- a/caffe2/operators/generate_proposals_op_util_boxes_test.cc +++ b/caffe2/operators/generate_proposals_op_util_boxes_test.cc @@ -105,4 +105,33 @@ TEST(UtilsBoxesTest, TestBboxTransformRotatedNormalized) { EXPECT_NEAR((result.matrix() - result_gt).norm(), 0.0, 1e-2); } +TEST(UtilsBoxesTest, ClipRotatedBoxes) { + // Test utils::clip_boxes_rotated() + using EMatXf = Eigen::MatrixXf; + + int height = 800; + int width = 600; + EMatXf bbox(5, 5); + bbox << 20, 20, 200, 150, 0, // Horizontal + 20, 20, 200, 150, 0.5, // Almost horizontal + 20, 20, 200, 150, 30, // Rotated + 300, 300, 200, 150, 30, // Rotated + 579, 779, 200, 150, -0.5; // Almost horizontal + + // Test with no clipping + float angle_thresh = -1.0; + auto result = utils::clip_boxes(bbox.array(), height, width, angle_thresh); + EXPECT_NEAR((result.matrix() - bbox).norm(), 0.0, 1e-4); + + EMatXf result_gt(5, 5); + result_gt << 59.75, 47.25, 120.5, 95.5, 0, 59.75, 47.25, 120.5, 95.5, 0.5, 20, + 20, 200, 150, 30, 300, 300, 200, 150, 30, 539.25, 751.75, 120.5, 95.5, + -0.5; + + // Test clipping with tolerance + angle_thresh = 1.0; + result = utils::clip_boxes(bbox.array(), height, width, angle_thresh); + EXPECT_NEAR((result.matrix() - result_gt).norm(), 0.0, 1e-4); +} + } // namespace caffe2 diff --git a/caffe2/operators/onnxifi_op.h b/caffe2/operators/onnxifi_op.h index 965bf876c60ffa..3c5cd2dbc36e61 100644 --- a/caffe2/operators/onnxifi_op.h +++ b/caffe2/operators/onnxifi_op.h @@ -84,7 +84,7 @@ class OnnxifiOp final : public Operator { // should retry until it get consistent. For now, we don't do that. 
CAFFE_ENFORCE_EQ( lib_->onnxGetBackendIDs(nullptr, &num_backends_), - ONNXIFI_STATUS_SUCCESS); + ONNXIFI_STATUS_FALLBACK); CAFFE_ENFORCE_GT( num_backends_, 0, "At least 1 onnxifi backend should be available"); backend_ids_.resize(num_backends_); diff --git a/caffe2/operators/utility_ops.cu b/caffe2/operators/utility_ops.cu index 4dee30bd16bdaf..9e68790f0a262c 100644 --- a/caffe2/operators/utility_ops.cu +++ b/caffe2/operators/utility_ops.cu @@ -4,17 +4,97 @@ // and std::isinf are declared constexpr there and the nvidia // compiler throws an error because of it +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/flatten_op.h" +#include "caffe2/operators/minmax_ops.h" +#include "caffe2/operators/utility_ops.h" +#include "caffe2/utils/math.h" + #include #include #include #include #include -#include "caffe2/core/context_gpu.h" -#include "flatten_op.h" -#include "minmax_ops.h" -#include "utility_ops.h" namespace caffe2 { + +template <> +bool WeightedSumOp::RunOnDevice() { + if (Input(0).IsType()) { + return DoRunWithType(); + } else if (Input(0).IsType()) { + return DoRunWithType(); + } else { + CAFFE_THROW("Unsupported inputs"); + } + return false; +} + +template <> +bool SumOp::RunOnDevice() { + if (Input(0).IsType()) { + return DoRunWithType(); + } else if (Input(0).IsType()) { + return DoRunWithType(); + } else { + CAFFE_THROW("Unsupported inputs"); + } + return false; +} + +template <> +class CopyOnDeviceLikeOp + : public Operator { + public: + CopyOnDeviceLikeOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws) {} + USE_OPERATOR_FUNCTIONS(CUDAContext); + + bool RunOnDevice() override { + auto& input = Input(0); + auto* output = OperatorBase::Output>(0); + CUDAContext context(GetGPUIDForPointer(Input(1).raw_data())); + output->ResizeLike(input); + context.template CopyItems( + input.meta(), + input.size(), + input.raw_data(), + output->raw_mutable_data(input.meta())); + return true; + } +}; + +REGISTER_CUDA_OPERATOR(Print, PrintOp); +REGISTER_CUDA_OPERATOR(Flatten, FlattenOp); +REGISTER_CUDA_OPERATOR(FlattenToVec, FlattenToVecOp); +REGISTER_CUDA_OPERATOR(Alias, AliasOp); +REGISTER_CUDA_OPERATOR(ResizeLike, ResizeLikeOp); +REGISTER_CUDA_OPERATOR(Sum, SumOp); +REGISTER_CUDA_OPERATOR(WeightedSum, WeightedSumOp); + +// From CPU, copy it to whatever the current context +REGISTER_CUDA_OPERATOR( + CopyFromCPUInput, + CopyOp); + +// CopyGPUToCPU and CopyCPUToGPU should both be carried out in a cuda context, +// since gpu code will be involved. +REGISTER_CUDA_OPERATOR( + CopyGPUToCPU, + CopyOp); +REGISTER_CUDA_OPERATOR( + CopyCPUToGPU, + CopyOp); +// If we only specify Copy, we assume that it is a gpu to gpu copy - maybe +// involving different GPUs. 
+REGISTER_CUDA_OPERATOR(Copy, CopyOp); + +REGISTER_CUDA_OPERATOR( + CopyOnDeviceLike, + CopyOnDeviceLikeOp); + +REGISTER_CUDA_OPERATOR(UnsafeCoalesce, UnsafeCoalesceOp); + CAFFE_KNOWN_TYPE(const float*); REGISTER_CUDA_OPERATOR(EnsureDense, EnsureDenseOp); diff --git a/caffe2/operators/utility_ops_gpu.cc b/caffe2/operators/utility_ops_gpu.cc deleted file mode 100644 index 674f78077f6d90..00000000000000 --- a/caffe2/operators/utility_ops_gpu.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "caffe2/core/context_gpu.h" -#include "caffe2/operators/flatten_op.h" -#include "caffe2/operators/utility_ops.h" -#include "caffe2/utils/math.h" - -namespace caffe2 { - -template <> -bool WeightedSumOp::RunOnDevice() { - if (Input(0).IsType()) { - return DoRunWithType(); - } else if (Input(0).IsType()) { - return DoRunWithType(); - } else { - CAFFE_THROW("Unsupported inputs"); - } - return false; -} - -template <> -bool SumOp::RunOnDevice() { - if (Input(0).IsType()) { - return DoRunWithType(); - } else if (Input(0).IsType()) { - return DoRunWithType(); - } else { - CAFFE_THROW("Unsupported inputs"); - } - return false; -} - -template <> -class CopyOnDeviceLikeOp - : public Operator { - public: - CopyOnDeviceLikeOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - USE_OPERATOR_FUNCTIONS(CUDAContext); - - bool RunOnDevice() override { - auto& input = Input(0); - auto* output = OperatorBase::Output>(0); - CUDAContext context(GetGPUIDForPointer(Input(1).raw_data())); - output->ResizeLike(input); - context.template CopyItems( - input.meta(), - input.size(), - input.raw_data(), - output->raw_mutable_data(input.meta())); - return true; - } -}; - -REGISTER_CUDA_OPERATOR(Print, PrintOp); -REGISTER_CUDA_OPERATOR(Flatten, FlattenOp); -REGISTER_CUDA_OPERATOR(FlattenToVec, FlattenToVecOp); -REGISTER_CUDA_OPERATOR(Alias, AliasOp); -REGISTER_CUDA_OPERATOR(ResizeLike, ResizeLikeOp); -REGISTER_CUDA_OPERATOR(Sum, SumOp); -REGISTER_CUDA_OPERATOR(WeightedSum, WeightedSumOp); - -// From CPU, copy it to whatever the current context -REGISTER_CUDA_OPERATOR( - CopyFromCPUInput, - CopyOp); - -// CopyGPUToCPU and CopyCPUToGPU should both be carried out in a cuda context, -// since gpu code will be involved. -REGISTER_CUDA_OPERATOR( - CopyGPUToCPU, - CopyOp); -REGISTER_CUDA_OPERATOR( - CopyCPUToGPU, - CopyOp); -// If we only specify Copy, we assume that it is a gpu to gpu copy - maybe -// involving different GPUs. 
-REGISTER_CUDA_OPERATOR(Copy, CopyOp); - -REGISTER_CUDA_OPERATOR( - CopyOnDeviceLike, - CopyOnDeviceLikeOp); - -REGISTER_CUDA_OPERATOR(UnsafeCoalesce, UnsafeCoalesceOp); - -} // namespace caffe2 diff --git a/caffe2/opt/converter.cc b/caffe2/opt/converter.cc index cc5be29120220d..b4866618b4e607 100644 --- a/caffe2/opt/converter.cc +++ b/caffe2/opt/converter.cc @@ -42,6 +42,14 @@ std::vector getDilations(std::map argMap) { return dilations; } +int getGroup(std::map& argMap) { + if (argMap.count("group")) { + CAFFE_ENFORCE(argMap["group"].has_i() && "Invalid group argument"); + return static_cast(argMap["group"].i()); + } + return 1; +} + } // namespace namespace caffe2 { @@ -115,6 +123,8 @@ class ConvConverter : public Converter { c->setStrides(getStrides(argMap)); c->setPads(getPads(argMap)); c->setDilations(getDilations(argMap)); + c->setGroup(getGroup(argMap)); + return nnOp; } // Does not override default converter to OperatorDef diff --git a/caffe2/opt/optimize_ideep.cc b/caffe2/opt/optimize_ideep.cc index 5a6643c2aa67ae..d880987fc6891b 100644 --- a/caffe2/opt/optimize_ideep.cc +++ b/caffe2/opt/optimize_ideep.cc @@ -1,51 +1,387 @@ #include "caffe2/opt/optimize_ideep.h" #include "caffe2/opt/converter.h" #include "caffe2/opt/fusion.h" -#include "caffe2/utils/proto_utils.h" + +#ifdef CAFFE2_USE_IDEEP +#include "caffe2/ideep/ideep_utils.h" +#endif namespace caffe2 { namespace opt { using namespace nom; -void OptimizeForIdeep(repr::NNModule* nn) { - // Conv+Relu fusion - auto should_fuse = [](const repr::Conv& conv) { - const auto annotation = conv.getAnnotation(); - if (!annotation || !isa(annotation)) { - return false; +#ifndef CAFFE2_USE_IDEEP +void OptimizeForIdeep( + repr::NNModule* nn, + caffe2::Workspace* ws, + bool training_mode) { + LOG(WARNING) << "Only support optimizations for IDEEP"; +} + +#else +USE_IDEEP_DEF_ALIASES(); + +Blob* getBlob(repr::NNGraph::NodeRef node, caffe2::Workspace* ws) { + auto tensor = repr::nn::get(node); + CAFFE_ENFORCE(ws->HasBlob(tensor->getName()), "Blob not in workspace"); + return ws->GetBlob(tensor->getName()); +} + +template +T* getTensor(Blob* blob) { + CAFFE_ENFORCE(blob, "Blob is invalid"); + if (blob && blob->template IsType()) { + return blob->template GetMutable(); + } + return nullptr; +} + +const caffe2::OperatorDef& getOpDef(const repr::NeuralNetOperator& nnOp) { + auto annotation = nnOp.getAnnotation(); + if (annotation == nullptr) { + CAFFE_THROW("Cannot get Operator annotation"); + } + return dyn_cast(annotation)->getOperatorDef(); +} + +caffe2::OperatorDef* getMutableOpDef(repr::NeuralNetOperator& nnOp) { + auto annotation = nnOp.getMutableAnnotation(); + if (annotation == nullptr) { + CAFFE_THROW("Cannot get Operator annotation"); + } + return dyn_cast(annotation)->getMutableOperatorDef(); +} + +bool isOnIdeepDevice(const repr::NeuralNetOperator& nnOp) { + // We only want to fuse for IDEEP convs + const auto& op = getOpDef(nnOp); + return op.device_option().device_type() == DeviceType::IDEEP; +} + +bool shouldFuseConv(const repr::Conv& conv) { + return isOnIdeepDevice(conv) ? 
(conv.getGroup() <= 1) : false; +} + +void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) { + // Fusion types: + // FUSION_CONV_RELU = 1 + // FUSION_CONV_SUM = 2 + // FUSION_CONV_SUM_RELU = 3 + auto conv = repr::nn::get(convNode); + auto annotation = conv->getMutableAnnotation(); + if (!annotation || !isa(annotation)) { + return; + } + + auto* op = getMutableOpDef(*conv); + if (op == nullptr) { + return; + } + + if (op->type() == "ConvFusion") { + CAFFE_ENFORCE(fusion_type == 1, "Invalid nest fusion"); + for (auto& arg : *op->mutable_arg()) { + if (arg.name() == "fusion_type") { + // Only from FUSION_CONV_SUM to FUSION_CONV_SUM_RELU + CAFFE_ENFORCE(arg.i() == 2, "Invalid nest fusion"); + arg.set_i(3); + return; + } + } + return; + } + + CAFFE_ENFORCE(fusion_type < 3, "Invalid fusion type"); + op->set_type("ConvFusion"); + auto* arg = op->add_arg(); + arg->set_name("fusion_type"); + arg->set_i(fusion_type); +} + +bool fuseConvBNHelperForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) { + for (auto node_pair : repr::nn::dataIterator(nn->dataFlow)) { + bool no_bias = false; + repr::NNGraph::NodeRef convNode; + repr::Conv* conv; + std::tie(conv, convNode) = node_pair; + + if (!isOnIdeepDevice(*conv)) { + LOG(WARNING) << "Not a IDEEP operator"; + continue; } - const auto& op = dyn_cast(annotation)->getOperatorDef(); - // We only want to fuse for IDEEP convs - if (op.device_option().device_type() != DeviceType::IDEEP) { - return false; + const auto& op = getOpDef(*conv); + if (op.type() == "ConvFusion") { + continue; } - // IDEEP doesn't support fusion group conv - int group = - ArgumentHelper::GetSingleArgument(op, "group", 1); - if (group != 1) { - return false; + auto convOutput = repr::nn::getOutputs(convNode).front(); + auto consumers = repr::nn::getConsumers(convOutput); + // convOutput is NOT referenced by sequential ops after BN. 
+ if (consumers.size() != 1) { + continue; } + auto consumer = consumers.front(); + if (!repr::nn::is(consumer)) { + continue; + } + auto bnNode = consumer; + auto bn = repr::nn::get(bnNode); + auto bnOutput = repr::nn::getOutputs(bnNode).front(); + + auto convInputs = repr::nn::getInputs(convNode); + if (convInputs.size() < 2) { + LOG(WARNING) << "Invalid convolution input size"; + continue; + } + + auto bnInputs = repr::nn::getInputs(bnNode); + if (bnInputs.size() < 5) { + LOG(WARNING) << "Invalid batch normalization input size"; + continue; + } + + // When no bias, borrow BN bias + if (convInputs.size() < 3) { + no_bias = true; + nn->dataFlow.createEdge(bnInputs[2], convNode); + convInputs = repr::nn::getInputs(convNode); + } + +#define EXPOSE_TENSOR_DATA(name, index, nodes) \ + auto* name = getTensor(getBlob(nodes[index], ws)); \ + if (name == nullptr) { \ + LOG(WARNING) << #name " not a IDEEP tensor"; \ + continue; \ + } \ + itensor name##Tensor({name->get_dims(), name->get_data_type()}); \ + name##Tensor.reorder_from(*name); \ + CAFFE_ENFORCE( \ + name##Tensor.is_public_format(), #name " not with public format"); \ + auto* name##Data = static_cast(name##Tensor.get_data_handle()); + + EXPOSE_TENSOR_DATA(filter, 1, convInputs); + EXPOSE_TENSOR_DATA(biasConv, 2, convInputs); + + EXPOSE_TENSOR_DATA(scale, 1, bnInputs); + EXPOSE_TENSOR_DATA(biasBN, 2, bnInputs); + EXPOSE_TENSOR_DATA(mean, 3, bnInputs); + EXPOSE_TENSOR_DATA(variance, 4, bnInputs); + +#undef EXPOSE_TENSOR_DATA + + // Assume M{CHW,HWC} + auto chwDim = filterTensor.get_dim(1) * filterTensor.get_dim(2) * + filterTensor.get_dim(3); + for (auto c = 0; c < filterTensor.get_dim(0); ++c) { + float coeff = + scaleData[c] / std::sqrt(varianceData[c] + bn->getEpsilon()); + for (auto i = 0; i < chwDim; ++i) { + filterData[c * chwDim + i] *= coeff; + } + if (no_bias) { + biasConvData[c] = biasBNData[c] - meanData[c] * coeff; + } else { + biasConvData[c] = + biasBNData[c] + (biasConvData[c] - meanData[c]) * coeff; + } + } + + filter->reorder_from(filterTensor); + biasConv->reorder_from(biasConvTensor); + nn->dataFlow.replaceNode(convOutput, bnOutput); + + nn->dataFlow.deleteNode(bnNode); + nn->dataFlow.deleteNode(convOutput); + return true; - }; - auto postprocess = [](repr::NNGraph::NodeRef conv_node) { - auto conv = repr::nn::get(conv_node); - auto annotation = conv->getMutableAnnotation(); - if (!annotation || !isa(annotation)) { - return; + } + + return false; +} + +void fuseConvBNForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) { + while (fuseConvBNHelperForIdeep(nn, ws)) { + } +} + +void fuseConvSumForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) { + // Assume the order of nodes from getMutableNodes conforms to + // the original topo order of operators + auto allNodes = nn->dataFlow.getMutableNodes(); + for (int i = 0; i < allNodes.size(); i++) { + auto sumNode = allNodes[i]; + if (!repr::nn::hasInputs(sumNode)) { + continue; + } + + if (!repr::nn::is(sumNode)) { + continue; } - auto* op = dyn_cast(annotation)->getMutableOperatorDef(); - op->set_type("ConvFusion"); - auto* arg = op->add_arg(); - arg->set_name("fusion_type"); - // 1 means FUSION_CONV_RELU - arg->set_i(1); - }; + + auto sum = repr::nn::get(sumNode); + if (!isOnIdeepDevice(*sum)) { + LOG(WARNING) << "Not a IDEEP operator"; + continue; + } + + auto sumInputs = repr::nn::getInputs(sumNode); + if (sumInputs.size() != 2) { + continue; + } + + bool should_fuse = true; + for (auto input : sumInputs) { + auto consumer = repr::nn::getConsumers(input).back(); + if 
(consumer != sumNode) { + should_fuse = false; + break; + } + } + // Sum inputs should not be referenced by sequential ops. + if (!should_fuse) { + continue; + } + + int j = i - 1; + repr::NNGraph::NodeRef convNode = nullptr; + while (j-- >= 0) { + if (!repr::nn::hasInputs(sumNode)) { + continue; + } + + // Find the nearest Op before Sum + if (repr::nn::is(allNodes[j])) { + // The Op must be a Conv + if (repr::nn::is(allNodes[j])) { + convNode = allNodes[j]; + } + break; + } + } + if (convNode == nullptr) { + continue; + } + + auto conv = repr::nn::get(convNode); + if (!shouldFuseConv(*conv)) { + LOG(WARNING) << "Not a IDEEP operator"; + continue; + } + + auto convOutput = repr::nn::getOutputs(convNode).front(); + repr::NNGraph::NodeRef sumInputX = + (sumInputs[0] == convOutput ? sumInputs[1] : sumInputs[0]); + CAFFE_ENFORCE(sumInputX != nullptr, "Invalid sum inputs"); + + auto preNode = repr::nn::getProducer(sumInputX); + if (preNode == nullptr || !repr::nn::is(preNode)) { + LOG(WARNING) << "Can not fuse Conv Sum"; + continue; + } + + auto newOutputName = repr::nn::get(sumInputX)->getName(); + auto newOutputTensor = util::make_unique(newOutputName); + auto newOutput = nn->dataFlow.createNode( + unique_dyn_cast(newOutputTensor)); + + auto sumOutput = repr::nn::getOutputs(sumNode).front(); + nn->dataFlow.replaceNode(sumOutput, newOutput); + + // 2 means FUSION_CONV_SUM + resetConvForFusion(convNode, 2); + nn->dataFlow.createEdge(sumInputX, convNode); + nn->dataFlow.createEdge(convNode, newOutput); + + nn->dataFlow.deleteNode(sumNode); + nn->dataFlow.deleteNode(sumOutput); + nn->dataFlow.deleteNode(convOutput); + } +} + +void fuseActivationForIdeep(repr::NNModule* nn) { + // Conv+Relu fusion + auto should_fuse = shouldFuseConv; + auto postprocess = std::bind(resetConvForFusion, std::placeholders::_1, 1); fuseActivation(nn, should_fuse, postprocess); } +void enforceFusionInplaceForIdeep(repr::NNModule* nn) { + // For fusions of Conv+Sum or Conv+Sum+ReLU, the last input and output must + // be inplaced. To enforce inplace, here to re-check whole graph and correct + // the ConvFusion Ops. 
+ for (auto node_pair : repr::nn::dataIterator(nn->dataFlow)) { + repr::NNGraph::NodeRef convNode; + repr::Conv* conv; + std::tie(conv, convNode) = node_pair; + + if (!isOnIdeepDevice(*conv)) { + LOG(WARNING) << "Not a IDEEP operator"; + continue; + } + + const auto& op = getOpDef(*conv); + if (op.type() != "ConvFusion") { + continue; + } + + bool enforce_inplace = false; + for (const auto& arg : op.arg()) { + // Only check FUSION_SUM & FUSION_SUM_RELU + if (arg.name() == "fusion_type" && (arg.i() == 2 || arg.i() == 3)) { + enforce_inplace = true; + break; + } + } + + if (!enforce_inplace) { + continue; + } + + auto convInput = repr::nn::getInputs(convNode).back(); + auto inputName = repr::nn::get(convInput)->getName(); + auto convOutput = repr::nn::getOutputs(convNode).front(); + auto outputName = repr::nn::get(convOutput)->getName(); + if (inputName == outputName) { + continue; + } + + auto consumer = repr::nn::getConsumers(convInput).back(); + if (consumer != convNode) { + LOG(ERROR) << "Can not enforce to inplace for fusion"; + return; + } + + auto newOutputTensor = util::make_unique(inputName); + auto newOutput = nn->dataFlow.createNode( + unique_dyn_cast(newOutputTensor)); + nn->dataFlow.replaceNode(convOutput, newOutput); + + nn->dataFlow.deleteNode(convOutput); + } +} + +void OptimizeForIdeep( + repr::NNModule* nn, + caffe2::Workspace* ws, + bool training_mode) { + if (training_mode) { + // Only support inference so far + return; + } + + fuseConvBNForIdeep(nn, ws); + + fuseConvSumForIdeep(nn, ws); + + fuseActivationForIdeep(nn); + + enforceFusionInplaceForIdeep(nn); +} + +#endif // CAFFE2_USE_IDEEP + } // namespace opt } // namespace caffe2 diff --git a/caffe2/opt/optimize_ideep.h b/caffe2/opt/optimize_ideep.h index 2e05ba4f095ef1..24635785336e57 100644 --- a/caffe2/opt/optimize_ideep.h +++ b/caffe2/opt/optimize_ideep.h @@ -1,13 +1,16 @@ #pragma once #include "caffe2/core/common.h" +#include "caffe2/core/workspace.h" #include "caffe2/proto/caffe2.pb.h" #include "nomnigraph/Representations/NeuralNet.h" namespace caffe2 { namespace opt { -void OptimizeForIdeep(nom::repr::NNModule* nn); - +void OptimizeForIdeep( + nom::repr::NNModule* nn, + caffe2::Workspace* ws, + bool training_mode = false); } } // namespace caffe2 diff --git a/caffe2/python/ideep/conv_op_test.py b/caffe2/python/ideep/conv_op_test.py index 7bb9fd37e1d40f..352b88995757ce 100644 --- a/caffe2/python/ideep/conv_op_test.py +++ b/caffe2/python/ideep/conv_op_test.py @@ -7,7 +7,9 @@ import hypothesis.strategies as st from hypothesis import given, settings import numpy as np +from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace +from caffe2.python.transformations import optimizeForIDEEP import caffe2.python.hypothesis_test_util as hu import caffe2.python.ideep_test_util as mu @@ -63,12 +65,60 @@ def test_depthwise_convolution(self, batch_size, gc, dc): pad=0, kernel=1, group=4, + device_option=dc[0] + ) + op1 = core.CreateOperator( + "Conv", + ["X", "w", "b"], + ["Y"], + stride=1, + pad=0, + kernel=1, + group=4, + device_option=dc[1] ) X = np.random.rand(batch_size, 544, 14, 14).astype(np.float32) w = np.random.rand(544, 136, 1, 1).astype(np.float32) b = np.random.rand(544).astype(np.float32) - inputs = [X, w, b] - self.assertDeviceChecks(dc, op, inputs, [0]) + + workspace.SwitchWorkspace("_device_check_", True) + workspace.FeedBlob('X', X, dc[0]) + workspace.FeedBlob('w', w, dc[0]) + workspace.FeedBlob('b', b, dc[0]) + workspace.RunOperatorOnce(op) + Y0 = workspace.FetchBlob('Y') + + 
workspace.ResetWorkspace() + workspace.FeedBlob('X', X, dc[1]) + workspace.FeedBlob('w', w, dc[1]) + workspace.FeedBlob('b', b, dc[1]) + net = core.Net("net") + old_net = caffe2_pb2.NetDef() + old_net.op.extend([op1]) + net.Proto().CopyFrom(old_net) + optimizeForIDEEP(net) + workspace.RunOperatorOnce(net.Proto().op[0]) + Y1 = workspace.FetchBlob('Y') + + if not np.allclose(Y0, Y1, atol=0.01, rtol=0.01): + print(Y1.flatten()) + print(Y0.flatten()) + print(np.max(np.abs(Y1 - Y0))) + self.assertTrue(False) + + workspace.ResetWorkspace() + workspace.FeedBlob('X', X, dc[1]) + workspace.FeedBlob('w', w, dc[1]) + workspace.FeedBlob('b', b, dc[1]) + workspace.RunOperatorOnce(op1) + Y2 = workspace.FetchBlob('Y') + + if not np.allclose(Y0, Y2, atol=0.01, rtol=0.01): + print(Y2.flatten()) + print(Y0.flatten()) + print(np.max(np.abs(Y2 - Y0))) + self.assertTrue(False) + if __name__ == "__main__": unittest.main() diff --git a/caffe2/python/ideep/convfusion_op_test.py b/caffe2/python/ideep/convfusion_op_test.py index 2a032924446b8d..0863e2b6e1125d 100644 --- a/caffe2/python/ideep/convfusion_op_test.py +++ b/caffe2/python/ideep/convfusion_op_test.py @@ -47,7 +47,7 @@ def test_convolution_relu_fusion(self, stride, pad, kernel, size, device_option=dc[0] ) - # Manual fusion + # Manual fusion for Conv + ReLU conv_fusion = core.CreateOperator( "ConvFusion", ["X1", "w1", "b1"] if use_bias else ["X1", "w1"], @@ -60,21 +60,6 @@ def test_convolution_relu_fusion(self, stride, pad, kernel, size, device_option=dc[1] ) - # Auto fusion - old_net = caffe2_pb2.NetDef() - conv_old = caffe2_pb2.OperatorDef() - conv_old.CopyFrom(conv) - conv_old.device_option.CopyFrom(dc[1]) - relu_old = caffe2_pb2.OperatorDef() - relu_old.CopyFrom(relu) - relu_old.device_option.CopyFrom(dc[1]) - old_net.op.extend([conv_old, relu_old]) - net = core.Net("net") - net.Proto().CopyFrom(old_net) - optimizeForIDEEP(net) - self.assertTrue(len(net.Proto().op) == 1) - self.assertTrue(net.Proto().op[0].type == "ConvFusion") - X = np.random.rand( batch_size, input_channels * group, size, size).astype(np.float32) - 0.5 w = np.random.rand( @@ -103,10 +88,24 @@ def test_convolution_relu_fusion(self, stride, pad, kernel, size, print(np.max(np.abs(Y1 - Y0))) self.assertTrue(False) + # Auto fusion for Conv + ReLU workspace.ResetWorkspace() + old_net = caffe2_pb2.NetDef() + conv_old = caffe2_pb2.OperatorDef() + conv_old.CopyFrom(conv) + conv_old.device_option.CopyFrom(dc[1]) + relu_old = caffe2_pb2.OperatorDef() + relu_old.CopyFrom(relu) + relu_old.device_option.CopyFrom(dc[1]) + old_net.op.extend([conv_old, relu_old]) workspace.FeedBlob('X0', X, dc[1]) workspace.FeedBlob('w0', w, dc[1]) workspace.FeedBlob('b0', b, dc[1]) + net = core.Net("net") + net.Proto().CopyFrom(old_net) + optimizeForIDEEP(net) + self.assertTrue(len(net.Proto().op) == 1) + self.assertTrue(net.Proto().op[0].type == "ConvFusion") workspace.RunOperatorOnce(net.Proto().op[0]) Y2 = workspace.FetchBlob('Y0') if not np.allclose(Y0, Y2, atol=0.01, rtol=0.01): @@ -130,6 +129,12 @@ def test_convolution_relu_fusion(self, stride, pad, kernel, size, def test_convolution_sum_fusion(self, stride, pad, kernel, size, input_channels, output_channels, batch_size, use_bias, group, gc, dc): + relu_S0 = core.CreateOperator( + "Relu", + ["S0"], + ["S0"], + device_option=dc[0] + ) conv = core.CreateOperator( "Conv", ["X0", "w0", "b0"] if use_bias else ["X0", "w0"], @@ -146,6 +151,14 @@ def test_convolution_sum_fusion(self, stride, pad, kernel, size, ["S0"], device_option=dc[0] ) + + # Manual fusion for Conv 
+ Sum + relu_S1 = core.CreateOperator( + "Relu", + ["S1"], + ["S1"], + device_option=dc[1] + ) conv_fusion = core.CreateOperator( "ConvFusion", ["X1", "w1", "b1", "S1"] if use_bias else ["X1", "w1", "S1"], @@ -173,6 +186,7 @@ def test_convolution_sum_fusion(self, stride, pad, kernel, size, Y0 = workspace.FetchBlob('Y0') S = np.random.rand(*Y0.shape).astype(np.float32) - 0.5 workspace.FeedBlob('S0', S, dc[0]) + workspace.RunOperatorOnce(relu_S0) workspace.RunOperatorOnce(sum) S0 = workspace.FetchBlob('S0') @@ -181,6 +195,7 @@ def test_convolution_sum_fusion(self, stride, pad, kernel, size, workspace.FeedBlob('w1', w, dc[1]) workspace.FeedBlob('b1', b, dc[1]) workspace.FeedBlob('S1', S, dc[1]) + workspace.RunOperatorOnce(relu_S1) workspace.RunOperatorOnce(conv_fusion) S1 = workspace.FetchBlob('S1') @@ -189,6 +204,37 @@ def test_convolution_sum_fusion(self, stride, pad, kernel, size, print(S0.flatten()) print(np.max(np.abs(S1 - S0))) self.assertTrue(False) + + # Auto fusion for Conv + Sum + workspace.ResetWorkspace() + old_net = caffe2_pb2.NetDef() + relu_S0_old = caffe2_pb2.OperatorDef() + relu_S0_old.CopyFrom(relu_S0) + relu_S0_old.device_option.CopyFrom(dc[1]) + conv_old = caffe2_pb2.OperatorDef() + conv_old.CopyFrom(conv) + conv_old.device_option.CopyFrom(dc[1]) + sum_old = caffe2_pb2.OperatorDef() + sum_old.CopyFrom(sum) + sum_old.device_option.CopyFrom(dc[1]) + old_net.op.extend([relu_S0_old, conv_old, sum_old]) + workspace.FeedBlob('X0', X, dc[1]) + workspace.FeedBlob('w0', w, dc[1]) + workspace.FeedBlob('b0', b, dc[1]) + workspace.FeedBlob('S0', S, dc[1]) + net = core.Net("net") + net.Proto().CopyFrom(old_net) + optimizeForIDEEP(net) + self.assertTrue(len(net.Proto().op) == 2) + self.assertTrue(net.Proto().op[1].type == "ConvFusion") + workspace.RunNetOnce(net.Proto()) + S2 = workspace.FetchBlob('S0') + if not np.allclose(S0, S2, atol=0.01, rtol=0.01): + print(S2.flatten()) + print(S0.flatten()) + print(np.max(np.abs(S2 - S0))) + self.assertTrue(False) + workspace.SwitchWorkspace(old_ws_name) @given(stride=st.integers(1, 3), @@ -204,6 +250,12 @@ def test_convolution_sum_fusion(self, stride, pad, kernel, size, def test_convolution_sum_relu_fusion(self, stride, pad, kernel, size, input_channels, output_channels, batch_size, use_bias, group, gc, dc): + relu_S0 = core.CreateOperator( + "Relu", + ["S0"], + ["S0"], + device_option=dc[0] + ) conv = core.CreateOperator( "Conv", ["X0", "w0", "b0"] if use_bias else ["X0", "w0"], @@ -226,6 +278,14 @@ def test_convolution_sum_relu_fusion(self, stride, pad, kernel, size, ["S0"], device_option=dc[0] ) + + # Manual fusion for Conv + Sum + ReLU + relu_S1 = core.CreateOperator( + "Relu", + ["S1"], + ["S1"], + device_option=dc[1] + ) conv_fusion = core.CreateOperator( "ConvFusion", ["X1", "w1", "b1", "S1"] if use_bias else ["X1", "w1", "S1"], @@ -253,6 +313,7 @@ def test_convolution_sum_relu_fusion(self, stride, pad, kernel, size, Y0 = workspace.FetchBlob('Y0') S = np.random.rand(*Y0.shape).astype(np.float32) - 0.5 workspace.FeedBlob('S0', S, dc[0]) + workspace.RunOperatorOnce(relu_S0) workspace.RunOperatorOnce(sum) workspace.RunOperatorOnce(relu) S0 = workspace.FetchBlob('S0') @@ -262,6 +323,7 @@ def test_convolution_sum_relu_fusion(self, stride, pad, kernel, size, workspace.FeedBlob('w1', w, dc[1]) workspace.FeedBlob('b1', b, dc[1]) workspace.FeedBlob('S1', S, dc[1]) + workspace.RunOperatorOnce(relu_S1) workspace.RunOperatorOnce(conv_fusion) S1 = workspace.FetchBlob('S1') @@ -270,6 +332,40 @@ def test_convolution_sum_relu_fusion(self, stride, pad, 
kernel, size, print(S0.flatten()) print(np.max(np.abs(S1 - S0))) self.assertTrue(False) + + # Auto fusion for Conv + Sum + ReLU + workspace.ResetWorkspace() + old_net = caffe2_pb2.NetDef() + relu_S0_old = caffe2_pb2.OperatorDef() + relu_S0_old.CopyFrom(relu_S0) + relu_S0_old.device_option.CopyFrom(dc[1]) + conv_old = caffe2_pb2.OperatorDef() + conv_old.CopyFrom(conv) + conv_old.device_option.CopyFrom(dc[1]) + sum_old = caffe2_pb2.OperatorDef() + sum_old.CopyFrom(sum) + sum_old.device_option.CopyFrom(dc[1]) + relu_old = caffe2_pb2.OperatorDef() + relu_old.CopyFrom(relu) + relu_old.device_option.CopyFrom(dc[1]) + old_net.op.extend([relu_S0_old, conv_old, sum_old, relu_old]) + workspace.FeedBlob('X0', X, dc[1]) + workspace.FeedBlob('w0', w, dc[1]) + workspace.FeedBlob('b0', b, dc[1]) + workspace.FeedBlob('S0', S, dc[1]) + net = core.Net("net") + net.Proto().CopyFrom(old_net) + optimizeForIDEEP(net) + self.assertTrue(len(net.Proto().op) == 2) + self.assertTrue(net.Proto().op[1].type == "ConvFusion") + workspace.RunNetOnce(net.Proto()) + S2 = workspace.FetchBlob('S0') + if not np.allclose(S0, S2, atol=0.01, rtol=0.01): + print(S2.flatten()) + print(S0.flatten()) + print(np.max(np.abs(S2 - S0))) + self.assertTrue(False) + workspace.SwitchWorkspace(old_ws_name) if __name__ == "__main__": diff --git a/caffe2/python/onnx/test_onnxifi.py b/caffe2/python/onnx/test_onnxifi.py index 39278c95117adf..002287cf3b839c 100644 --- a/caffe2/python/onnx/test_onnxifi.py +++ b/caffe2/python/onnx/test_onnxifi.py @@ -14,7 +14,7 @@ from caffe2.python.onnx.tests.test_utils import TestCase class OnnxifiTest(TestCase): - @unittest.skipIf(not workspace.C.use_trt, "No TensortRT support") + @unittest.skip("Need ONNXIFI backend support") def test_relu_graph(self): batch_size = 1 X = np.random.randn(batch_size, 1, 3, 2).astype(np.float32) @@ -36,3 +36,47 @@ def test_relu_graph(self): workspace.RunOperatorOnce(op) Y = workspace.FetchBlob("Y") np.testing.assert_almost_equal(Y, np.maximum(X, 0)) + + @unittest.skip("Need ONNXIFI backend support") + def test_conv_graph(self): + X = np.array([[[[0., 1., 2., 3., 4.], # (1, 1, 5, 5) input tensor + [5., 6., 7., 8., 9.], + [10., 11., 12., 13., 14.], + [15., 16., 17., 18., 19.], + [20., 21., 22., 23., 24.]]]]).astype(np.float32) + W = np.array([[[[1., 1., 1.], # (1, 1, 3, 3) tensor for convolution weights + [1., 1., 1.], + [1., 1., 1.]]]]).astype(np.float32) + Y_without_padding = np.array([[[[54., 63., 72.], # (1, 1, 3, 3) output tensor + [99., 108., 117.], + [144., 153., 162.]]]]).astype(np.float32) + graph_def = make_graph( + [make_node( + 'Conv', + inputs=['X', 'W'], + outputs=['Y'], + kernel_shape=[3, 3], + # Default values for other attributes: strides=[1, 1], dilations=[1, 1], groups=1 + pads=[0, 0, 0, 0], + )], + name="test", + inputs=[make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 1, 5, 5]), + make_tensor_value_info("W", onnx.TensorProto.FLOAT, [1, 1, 3, 3]), + ], + outputs=[make_tensor_value_info("Y", onnx.TensorProto.FLOAT, + [1, 1, 3, 3])]) + model_def = make_model(graph_def, producer_name='conv-test') + op = core.CreateOperator( + "Onnxifi", + ["X", "W"], + ["Y"], + onnx_model=model_def.SerializeToString(), + initializers=["W", "W"], + output_size_hint_0=[1, 1, 3, 3]) + workspace.FeedBlob("X", X) + workspace.FeedBlob("W", W) + workspace.RunOperatorOnce(op) + Y = workspace.FetchBlob("Y") + np.testing.assert_almost_equal(Y, Y_without_padding) + + diff --git a/caffe2/python/operator_test/bbox_transform_test.py b/caffe2/python/operator_test/bbox_transform_test.py 
index 20008acaa626c0..b54a4435513be7 100644 --- a/caffe2/python/operator_test/bbox_transform_test.py +++ b/caffe2/python/operator_test/bbox_transform_test.py @@ -152,6 +152,42 @@ def bbox_transform_rotated( return pred_boxes +def clip_tiled_boxes_rotated(boxes, im_shape, angle_thresh=1.0): + """ + Similar to clip_tiled_boxes but for rotated boxes with angle info. + Only clips almost horizontal boxes within angle_thresh. The rest are + left unchanged. + """ + assert ( + boxes.shape[1] % 5 == 0 + ), "boxes.shape[1] is {:d}, but must be divisible by 5.".format( + boxes.shape[1] + ) + + (H, W) = im_shape[:2] + + # Filter boxes that are almost upright within angle_thresh tolerance + idx = np.where(np.abs(boxes[:, 4::5]) <= angle_thresh) + idx5 = idx[1] * 5 + # convert to (x1, y1, x2, y2) + x1 = boxes[idx[0], idx5] - (boxes[idx[0], idx5 + 2] - 1) / 2.0 + y1 = boxes[idx[0], idx5 + 1] - (boxes[idx[0], idx5 + 3] - 1) / 2.0 + x2 = boxes[idx[0], idx5] + (boxes[idx[0], idx5 + 2] - 1) / 2.0 + y2 = boxes[idx[0], idx5 + 1] + (boxes[idx[0], idx5 + 3] - 1) / 2.0 + # clip + x1 = np.maximum(np.minimum(x1, W - 1), 0) + y1 = np.maximum(np.minimum(y1, H - 1), 0) + x2 = np.maximum(np.minimum(x2, W - 1), 0) + y2 = np.maximum(np.minimum(y2, H - 1), 0) + # convert back to (xc, yc, w, h) + boxes[idx[0], idx5] = (x1 + x2) / 2.0 + boxes[idx[0], idx5 + 1] = (y1 + y2) / 2.0 + boxes[idx[0], idx5 + 2] = x2 - x1 + 1 + boxes[idx[0], idx5 + 3] = y2 - y1 + 1 + + return boxes + + def generate_rois_rotated(roi_counts, im_dims): rois = generate_rois(roi_counts, im_dims) # [batch_id, ctr_x, ctr_y, w, h, angle] @@ -161,7 +197,7 @@ def generate_rois_rotated(roi_counts, im_dims): rotated_rois[:, 2] = (rois[:, 2] + rois[:, 4]) / 2. # ctr_y = (y1 + y2) / 2 rotated_rois[:, 3] = rois[:, 3] - rois[:, 1] + 1.0 # w = x2 - x1 + 1 rotated_rois[:, 4] = rois[:, 4] - rois[:, 2] + 1.0 # h = y2 - y1 + 1 - rotated_rois[:, 5] = np.random.uniform(0.0, 360.0) # angle in degrees + rotated_rois[:, 5] = np.random.uniform(-90.0, 90.0) # angle in degrees return rotated_rois @@ -173,6 +209,7 @@ class TestBBoxTransformOp(hu.HypothesisTestCase): skip_batch_id=st.booleans(), rotated=st.booleans(), angle_bound_on=st.booleans(), + clip_angle_thresh=st.sampled_from([-1.0, 1.0]), **hu.gcs_cpu_only ) def test_bbox_transform( @@ -183,6 +220,7 @@ def test_bbox_transform( skip_batch_id, rotated, angle_bound_on, + clip_angle_thresh, gc, dc, ): @@ -202,14 +240,16 @@ def test_bbox_transform( def bbox_transform_ref(rois, deltas, im_info): boxes = rois if rois.shape[1] == box_dim else rois[:, 1:] + im_shape = im_info[0, 0:2] if rotated: box_out = bbox_transform_rotated( boxes, deltas, angle_bound_on=angle_bound_on ) - # No clipping for rotated boxes + box_out = clip_tiled_boxes_rotated( + box_out, im_shape, angle_thresh=clip_angle_thresh + ) else: box_out = bbox_transform(boxes, deltas) - im_shape = im_info[0, 0:2] box_out = clip_tiled_boxes(box_out, im_shape) return [box_out] @@ -221,6 +261,7 @@ def bbox_transform_ref(rois, deltas, im_info): correct_transform_coords=True, rotated=rotated, angle_bound_on=angle_bound_on, + clip_angle_thresh=clip_angle_thresh, ) self.assertReferenceChecks( @@ -235,10 +276,18 @@ def bbox_transform_ref(rois, deltas, im_info): num_classes=st.integers(1, 10), rotated=st.booleans(), angle_bound_on=st.booleans(), + clip_angle_thresh=st.sampled_from([-1.0, 1.0]), **hu.gcs_cpu_only ) def test_bbox_transform_batch( - self, roi_counts, num_classes, rotated, angle_bound_on, gc, dc + self, + roi_counts, + num_classes, + rotated, + angle_bound_on, + 
clip_angle_thresh, + gc, + dc, ): """ Test with rois for multiple images in a batch @@ -266,14 +315,16 @@ def bbox_transform_ref(rois, deltas, im_info): continue cur_boxes = rois[offset : offset + num_rois, 1:] cur_deltas = deltas[offset : offset + num_rois] + im_shape = im_info[i, 0:2] if rotated: cur_box_out = bbox_transform_rotated( cur_boxes, cur_deltas, angle_bound_on=angle_bound_on ) - # No clipping for rotated boxes + cur_box_out = clip_tiled_boxes_rotated( + cur_box_out, im_shape, angle_thresh=clip_angle_thresh + ) else: cur_box_out = bbox_transform(cur_boxes, cur_deltas) - im_shape = im_info[i, 0:2] cur_box_out = clip_tiled_boxes(cur_box_out, im_shape) box_out.append(cur_box_out) offset += num_rois @@ -292,6 +343,7 @@ def bbox_transform_ref(rois, deltas, im_info): correct_transform_coords=True, rotated=rotated, angle_bound_on=angle_bound_on, + clip_angle_thresh=clip_angle_thresh, ) self.assertReferenceChecks( diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index c71b8d2bf86411..e58f5ba6be1ebe 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -1526,12 +1526,12 @@ void addGlobalMethods(py::module& m) { // into a python interface in transformations.py // Prefix the transformation with transform_ to avoid clobbering the // function namespace. - m.def("transform_optimizeForIDEEP", [](py::bytes def) { + m.def("transform_optimizeForIDEEP", [](py::bytes def, bool training_mode) { caffe2::NetDef proto; CAFFE_ENFORCE(ParseProtoFromLargeString(def.cast(), &proto)); auto nn = caffe2::convertToNNModule(proto); - opt::OptimizeForIdeep(&nn); + opt::OptimizeForIdeep(&nn, gWorkspace, training_mode); auto new_proto = caffe2::convertToCaffe2Proto(nn, proto); std::string out; diff --git a/caffe2/python/transformations.py b/caffe2/python/transformations.py index c1971e1f7c1c01..b48b73d4cbf9e5 100644 --- a/caffe2/python/transformations.py +++ b/caffe2/python/transformations.py @@ -52,9 +52,9 @@ def sinkMaxPool(net): ) -def optimizeForIDEEP(net): +def optimizeForIDEEP(net, training_mode = False): net.Proto().ParseFromString( - C.transform_optimizeForIDEEP(net.Proto().SerializeToString()) + C.transform_optimizeForIDEEP(net.Proto().SerializeToString(), training_mode) ) diff --git a/modules/observers/net_observer_reporter_print.cc b/modules/observers/net_observer_reporter_print.cc index 0982e3d7ff2b4f..2355fedc9a1a76 100644 --- a/modules/observers/net_observer_reporter_print.cc +++ b/modules/observers/net_observer_reporter_print.cc @@ -7,6 +7,8 @@ namespace caffe2 { const std::string NetObserverReporterPrint::IDENTIFIER = "Caffe2Observer "; +static std::string get_op_args(PerformanceInformation p); +static std::string get_tensor_shapes(PerformanceInformation p); void NetObserverReporterPrint::report( NetBase* net, @@ -27,6 +29,9 @@ void NetObserverReporterPrint::report( {"flops", {{"value", "-1"}, {"unit", "flops"}}}}; } else if (p.first != "NET_DELAY") { // for operator perf + std::string shape_str = get_tensor_shapes(p.second); + std::string args_str = get_op_args(p.second); + caffe2_perf[p.first] = { {"latency", {{"value", caffe2::to_string(p.second.latency * 1000)}, @@ -36,7 +41,9 @@ void NetObserverReporterPrint::report( "value", caffe2::to_string(p.second.flops), }, - {"unit", "flops"}}}}; + {"unit", "flops"}}}, + {"tensor_shapes", {{"info_string", shape_str}, {"unit", ""}}}, + {"op_args", {{"info_string", args_str}, {"unit", ""}}}}; } } @@ -67,4 +74,52 @@ void NetObserverReporterPrint::report( LOG(INFO) << buffer.str(); } } + +static 
std::string get_tensor_shapes(PerformanceInformation p) { + std::string shape_str; + std::stringstream shape_stream; + if (!p.tensor_shapes.empty()) { + shape_stream << "["; + for (int i = 0; i < p.tensor_shapes.size(); i++) { + shape_stream << "["; + for (int j = 0; j < p.tensor_shapes[i].dims_size(); j++) { + shape_stream << p.tensor_shapes[i].dims(j) << ", "; + } + shape_stream << "], "; + } + shape_stream << "]"; + shape_str = shape_stream.str(); + } else { + shape_str = "[]"; + } + return shape_str; +} + +static std::string get_op_args(PerformanceInformation p) { + std::string args_str; + if (!p.args.empty()) { + std::stringstream args; + args << "["; + for (int i = 0; i < p.args.size(); i++) { + args << "{" << p.args[i].name() << ": "; + if (p.args[i].has_i()) { + args << p.args[i].i(); + } else if (p.args[i].has_s()) { + args << p.args[i].s(); + } else if (p.args[i].has_n()) { + args << &p.args[i].n(); + } else if (p.args[i].has_f()) { + args << p.args[i].f(); + } else { + args << "None"; + } + args << "}, "; + } + args << "]"; + args_str = args.str(); + } else { + args_str = "[]"; + } + return args_str; +} } diff --git a/test/cpp/api/integration.cpp b/test/cpp/api/integration.cpp index 27fddda2fe6f10..57e0f14eb2bab3 100644 --- a/test/cpp/api/integration.cpp +++ b/test/cpp/api/integration.cpp @@ -14,6 +14,7 @@ #include using namespace torch::nn; +using namespace torch::test; #include #include @@ -236,7 +237,7 @@ TEST_CASE("integration/cartpole") { torch::manual_seed(0); std::cerr << "Training episodic policy gradient with a critic for up to 3000" " episodes, rest your eyes for a bit!\n"; - auto model = std::make_shared(); + auto model = std::make_shared(); auto linear = model->add(Linear(4, 128), "linear"); auto policyHead = model->add(Linear(128, 2), "policy"); auto valueHead = model->add(Linear(128, 1), "action"); @@ -333,7 +334,7 @@ TEST_CASE("integration/cartpole") { TEST_CASE("integration/mnist", "[cuda]") { torch::manual_seed(0); - auto model = std::make_shared(); + auto model = std::make_shared(); auto conv1 = model->add(Conv2d(1, 10, 5), "conv1"); auto conv2 = model->add(Conv2d(10, 20, 5), "conv2"); auto drop = Dropout(0.3); @@ -369,7 +370,7 @@ TEST_CASE("integration/mnist", "[cuda]") { TEST_CASE("integration/mnist/batchnorm", "[cuda]") { torch::manual_seed(0); - auto model = std::make_shared(); + auto model = std::make_shared(); auto conv1 = model->add(Conv2d(1, 10, 5), "conv1"); auto batchnorm2d = model->add(BatchNorm(BatchNormOptions(10).stateful(true)), "batchnorm2d"); diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp index af790466d24f15..c8e7bdc605660d 100644 --- a/test/cpp/api/module.cpp +++ b/test/cpp/api/module.cpp @@ -6,7 +6,10 @@ #include #include +#include + using namespace torch::nn; +using namespace torch::test; using Catch::StartsWith; @@ -19,10 +22,6 @@ struct AGIUnit2 : torch::nn::Module { }; } // namespace test -bool pointer_equal(torch::Tensor first, torch::Tensor second) { - return first.data().data() == second.data().data(); -} - TEST_CASE("module/training-mode") { torch::manual_seed(0); Linear module(3, 4); @@ -195,6 +194,9 @@ TEST_CASE("module/clone") { SECTION("Cloning creates distinct parameters") { struct TestModule : public Cloneable { + TestModule() { + reset(); + } void reset() override { l1 = register_module("l1", Linear(10, 3)); l2 = register_module("l2", Linear(3, 5)); @@ -206,7 +208,7 @@ TEST_CASE("module/clone") { torch::Tensor buffer; }; - auto module = TestModule().build(); + auto module = std::make_shared(); auto module2 = 
module->clone(); auto params1 = module->parameters(); @@ -216,7 +218,7 @@ TEST_CASE("module/clone") { for (auto& param : params1) { REQUIRE(!pointer_equal(param.value, params2[param.key])); REQUIRE(param->allclose(params2[param.key])); - param->data().mul_(2); + param->data().add_(2); } for (auto& param : params1) { REQUIRE(!param->allclose(params2[param.key])); @@ -229,7 +231,7 @@ TEST_CASE("module/clone") { for (auto& buffer : buffers1) { REQUIRE(!pointer_equal(buffer.value, buffers2[buffer.key])); REQUIRE(buffer->allclose(buffers2[buffer.key])); - buffer->data().mul_(2); + buffer->data().add_(2); } for (auto& buffer : buffers1) { REQUIRE(!buffer->allclose(buffers2[buffer.key])); @@ -238,12 +240,15 @@ TEST_CASE("module/clone") { SECTION("Cloning preserves external references") { struct TestModule : public Cloneable { + TestModule() { + reset(); + } void reset() override { weight = register_parameter("weight", torch::ones({4, 4})); } torch::Tensor weight; }; - auto module = TestModule().build(); + auto module = std::make_shared(); module->weight.data() += 1; REQUIRE(pointer_equal(module->weight, module->parameters()["weight"])); REQUIRE(module->weight.allclose(module->parameters()["weight"])); @@ -259,6 +264,9 @@ TEST_CASE("module/clone") { SECTION("Cloning copies the values of variables of submodules") { struct TestModule : public Cloneable { + TestModule() { + reset(); + } void reset() override { weight = register_parameter("weight", torch::ones({4, 4})); } @@ -267,13 +275,16 @@ TEST_CASE("module/clone") { int value = 0; }; struct NestedModule : public Cloneable { + NestedModule() { + reset(); + } void reset() override { - module = register_module("module", TestModule().build()); + module = register_module("module", std::make_shared()); } std::shared_ptr module; }; - auto a = NestedModule().build(); + auto a = std::make_shared(); a->module->weight.data() += 1; a->module->value = 123; diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index ac40ec20811858..2cacc75579097f 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -13,6 +13,7 @@ #include using namespace torch::nn; +using namespace torch::test; class TestModel : public torch::nn::Module { public: @@ -122,7 +123,7 @@ TEST_CASE("modules") { } SECTION("simple") { - auto model = std::make_shared(); + auto model = std::make_shared(); auto l1 = model->add(Linear(10, 3), "l1"); auto l2 = model->add(Linear(3, 5), "l2"); auto l3 = model->add(Linear(5, 100), "l3"); diff --git a/test/cpp/api/rnn.cpp b/test/cpp/api/rnn.cpp index 9423cb1ccfae49..a87d08802a7708 100644 --- a/test/cpp/api/rnn.cpp +++ b/test/cpp/api/rnn.cpp @@ -9,13 +9,14 @@ #include using namespace torch::nn; +using namespace torch::test; template bool test_RNN_xor(Func&& model_maker, bool cuda = false) { torch::manual_seed(0); auto nhid = 32; - auto model = std::make_shared(); + auto model = std::make_shared(); auto l1 = model->add(Linear(1, nhid), "l1"); auto rnn = model->add(model_maker(nhid), "rnn"); auto lo = model->add(Linear(nhid, 1), "lo"); diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp index 0d608cd856481c..8aa608c4bb4787 100644 --- a/test/cpp/api/sequential.cpp +++ b/test/cpp/api/sequential.cpp @@ -9,7 +9,10 @@ #include #include +#include + using namespace torch::nn; +using namespace torch::test; using Catch::StartsWith; @@ -272,4 +275,31 @@ TEST_CASE("sequential") { return &first == &second; })); } + SECTION("Is cloneable") { + Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3)); + 
Sequential clone = + std::static_pointer_cast(sequential->clone()); + REQUIRE(sequential->size() == clone->size()); + + for (size_t i = 0; i < sequential->size(); ++i) { + // The modules should be the same kind (type). + REQUIRE(sequential[i]->name() == clone[i]->name()); + // But not pointer-equal (distinct objects). + REQUIRE(sequential[i] != clone[i]); + } + + // Verify that the clone is deep, i.e. parameters of modules are cloned too. + + auto params1 = sequential->parameters(); + auto params2 = clone->parameters(); + REQUIRE(params1.size() == params2.size()); + for (auto& param : params1) { + REQUIRE(!pointer_equal(param.value, params2[param.key])); + REQUIRE(param->allclose(params2[param.key])); + param->data().add_(2); + } + for (auto& param : params1) { + REQUIRE(!param->allclose(params2[param.key])); + } + } } diff --git a/test/cpp/api/util.h b/test/cpp/api/util.h index 32b62214151516..794e6424658bb7 100644 --- a/test/cpp/api/util.h +++ b/test/cpp/api/util.h @@ -6,6 +6,7 @@ #include namespace torch { +namespace test { // Lets you use a container without making a new class, // for experimental implementations @@ -20,4 +21,9 @@ class SimpleContainer : public nn::Cloneable { return Module::register_module(std::move(name), module_holder); } }; + +inline bool pointer_equal(torch::Tensor first, torch::Tensor second) { + return first.data().data() == second.data().data(); +} +} // namespace test } // namespace torch diff --git a/test/expect/TestCollectEnv.test_pytorch_linux_trusty_py27.expect b/test/expect/TestCollectEnv.test_pytorch_linux_trusty_py27.expect deleted file mode 100644 index cab5d177418f49..00000000000000 --- a/test/expect/TestCollectEnv.test_pytorch_linux_trusty_py27.expect +++ /dev/null @@ -1,19 +0,0 @@ -PyTorch version: 0.5.0a0 -Is debug build: No -CUDA used to build PyTorch: None - -OS: Ubuntu 14.04.X LTS -GCC version: (Ubuntu 4.8.4-2ubuntu1~14.04.4) 4.8.4 -CMake version: version 3.5.X - -Python version: 2.7 -Is CUDA available: No -CUDA runtime version: No CUDA -GPU models and configuration: No CUDA -Nvidia driver version: No CUDA -cuDNN version: No CUDA - -Versions of relevant libraries: -[pip] numpy (1.14.X) -[pip] torch (0.5.0a0) -[conda] Could not collect diff --git a/test/expect/TestCollectEnv.test_pytorch_linux_xenial_cuda9_cudnn7_py3.expect b/test/expect/TestCollectEnv.test_pytorch_linux_xenial_cuda9_cudnn7_py3.expect deleted file mode 100644 index 22eb2b6ebeefcf..00000000000000 --- a/test/expect/TestCollectEnv.test_pytorch_linux_xenial_cuda9_cudnn7_py3.expect +++ /dev/null @@ -1,25 +0,0 @@ -PyTorch version: 0.5.0a0 -Is debug build: No -CUDA used to build PyTorch: 9.0.X - -OS: Ubuntu 16.04.X LTS -GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.9) 5.4.0 20160609 -CMake version: version 3.11.X - -Python version: 3.6 -Is CUDA available: Yes -CUDA runtime version: 9.0.X -GPU models and configuration: -GPU 0: Tesla M60 -GPU 1: Tesla M60 - -Nvidia driver version: 396.X -cuDNN version: Probably one of the following: -/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.4 -/usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a - -Versions of relevant libraries: -[pip] numpy (1.14.X) -[pip] torch (0.5.0a0) -[conda] magma-cuda90 2.3.0 1 soumith -[conda] torch 0.5.0a0 diff --git a/test/expect/TestCollectEnv.test_pytorch_macos_1013_py3.expect b/test/expect/TestCollectEnv.test_pytorch_macos_1013_py3.expect deleted file mode 100644 index 09dd815124c0f2..00000000000000 --- a/test/expect/TestCollectEnv.test_pytorch_macos_1013_py3.expect +++ /dev/null @@ -1,19 +0,0 @@ -PyTorch version: 0.5.0a0 -Is debug 
build: No -CUDA used to build PyTorch: None - -OS: Mac OSX 10.13.X -GCC version: Could not collect -CMake version: version 3.11.X - -Python version: 3.6 -Is CUDA available: No -CUDA runtime version: No CUDA -GPU models and configuration: No CUDA -Nvidia driver version: No CUDA -cuDNN version: No CUDA - -Versions of relevant libraries: -[pip] numpy (1.14.X) -[pip] torch (0.5.0a0) -[conda] torch 0.5.0a0 diff --git a/test/expect/TestCollectEnv.test_pytorch_win_ws2016_cuda9_cudnn7_py3.expect b/test/expect/TestCollectEnv.test_pytorch_win_ws2016_cuda9_cudnn7_py3.expect deleted file mode 100644 index 6f68fc99f2ef27..00000000000000 --- a/test/expect/TestCollectEnv.test_pytorch_win_ws2016_cuda9_cudnn7_py3.expect +++ /dev/null @@ -1,19 +0,0 @@ -PyTorch version: 0.5.0a0 -Is debug build: No -CUDA used to build PyTorch: 9.0 - -OS: Microsoft Windows Server 2012 R2 Standard -GCC version: Could not collect -CMake version: version 3.10.X - -Python version: 3.6 -Is CUDA available: Yes -CUDA runtime version: 9.0.X -GPU models and configuration: GPU 0: Tesla M60 -Nvidia driver version: 390.X -cuDNN version: Probably one of the following: -C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0\bin\cudnn64_7.dll - -Versions of relevant libraries: -[pip] numpy (1.14.X) -[conda] Could not collect diff --git a/test/expect/TestJit.test_peephole.expect b/test/expect/TestJit.test_peephole.expect new file mode 100644 index 00000000000000..6acd3d9479e7de --- /dev/null +++ b/test/expect/TestJit.test_peephole.expect @@ -0,0 +1,4 @@ +graph(%0 : Double(1) + %1 : Double(1)) { + return (%0); +} diff --git a/test/expect/TestJit.test_peephole_cuda-different_device.expect b/test/expect/TestJit.test_peephole_cuda-different_device.expect new file mode 100644 index 00000000000000..6f399dac524458 --- /dev/null +++ b/test/expect/TestJit.test_peephole_cuda-different_device.expect @@ -0,0 +1,5 @@ +graph(%0 : Double(1) + %1 : Double(1)) { + %2 : Double(1) = aten::type_as(%0, %1) + return (%2); +} diff --git a/test/expect/TestJit.test_peephole_cuda-same_device.expect b/test/expect/TestJit.test_peephole_cuda-same_device.expect new file mode 100644 index 00000000000000..6acd3d9479e7de --- /dev/null +++ b/test/expect/TestJit.test_peephole_cuda-same_device.expect @@ -0,0 +1,4 @@ +graph(%0 : Double(1) + %1 : Double(1)) { + return (%0); +} diff --git a/test/test_autograd.py b/test/test_autograd.py index 3b519ea5a86874..1a734cf23a1b55 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -330,6 +330,17 @@ def test_grad_badcalls(self): y = x ** 2 torch.autograd.grad(y, x) # this should succeed now + def test_grad_fn_badcalls(self): + error_regex = 'expected .* arguments, got .* instead' + x = torch.ones(1, requires_grad=True) + y = x ** 2 + with self.assertRaisesRegex(TypeError, error_regex): + y.grad_fn(x.detach(), x.detach()) # too many + with self.assertRaisesRegex(TypeError, error_regex): + y.grad_fn() # too few + + y.grad_fn(x.detach()) # this should succeed + def test_grad_unreachable(self): x = torch.ones(1, requires_grad=True) y = torch.ones(1, requires_grad=True) diff --git a/test/test_jit.py b/test/test_jit.py index 0663b41b67e08f..765c19dbfeb788 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -28,7 +28,6 @@ from torch.jit.frontend import NotSupportedError from torch.jit import BatchTensor -import torch.jit.batchop try: import torchvision @@ -256,6 +255,48 @@ def f(x, y): self.assertExpectedGraph(trace) self.assertExportImport(trace, (x, y)) + def test_peephole(self): + a = torch.tensor([0.4], requires_grad=True) 
+ b = torch.tensor([0.7], requires_grad=True) + c = torch.tensor([0], dtype=torch.int32) + + def f(x, y): + return x.type_as(y) + + trace, z = torch.jit.get_trace_graph(f, (a, b)) + self.run_pass('peephole', trace) + self.assertExpectedGraph(trace) + trace, z = torch.jit.get_trace_graph(f, (a, c)) + s = str(trace) + self.run_pass('peephole', trace) + self.assertEqual(s, str(trace)) + + def test_peephole_dynamic(self): + def f(x, y): + return x.type_as(y) + + fn = torch.jit.script(f) + s = str(fn.graph) + torch._C._jit_pass_peephole(fn.graph) + self.assertEqual(s, str(fn.graph)) + + @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA") + def test_peephole_cuda(self): + a = torch.tensor([0.4], requires_grad=True, device='cpu') + b = torch.tensor([0.7], requires_grad=True, device='cuda') + c = torch.tensor([0.7], requires_grad=True, device='cuda') + + def f(x, y): + return x.type_as(y) + + trace, z = torch.jit.get_trace_graph(f, (a, c)) + s = str(trace) + self.run_pass('peephole', trace) + self.assertEqual(s, str(trace)) + trace, z = torch.jit.get_trace_graph(f, (b, c)) + self.run_pass('peephole', trace) + self.assertExpectedGraph(trace, subname="same_device") + def test_index(self): x = torch.tensor([0.4], requires_grad=True) y = torch.tensor([0], dtype=torch.int64) diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py index 099f4d774ccc60..e88a3578f2d9d2 100644 --- a/test/test_multiprocessing.py +++ b/test/test_multiprocessing.py @@ -207,18 +207,15 @@ def test_receive(): def _test_preserve_sharing(self, ctx=mp, repeat=1): def do_test(): x = torch.randn(5, 5) - data = [x.storage(), x.storage()[1:4], x, x[2], x[:, 1]] + data = [x.storage(), x, x[2], x[:, 1]] q = ctx.Queue() q.put(data) new_data = q.get(timeout=1) self.assertEqual(new_data, data, 0) storage_cdata = data[0]._cdata self.assertEqual(new_data[0]._cdata, storage_cdata) - for t in new_data[2:]: + for t in new_data[1:]: self.assertEqual(t.storage()._cdata, storage_cdata) - # TODO: enable after fixing #46 - # new_data[0].fill_(10) - # self.assertEqual(new_data[1], new_data[0][1:4], 0) with leak_checker(self): for i in range(repeat): @@ -335,7 +332,13 @@ def test_cuda_small_tensors(self): self.assertEqual(v, torch.arange(i * 5., (i + 1) * 5).sum()) self.assertEqual(device, i % 2) self.assertEqual(tensor_size, 5) - self.assertEqual(storage_size, 5) + # You might think this should be the case, but it's not! After + # data from the CUDA caching allocator goes through IPC, the + # size of the storage is the size of the *cached cudaMalloc for + # the entire memory block* of the storage, not just the storage. 
+ # See Note [CUDA IPC and the caching allocator] for more info + # + # self.assertEqual(storage_size, 5) @unittest.skipIf(IS_WINDOWS, 'not applicable to Windows (only fails with fork)') @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available') diff --git a/test/test_torch.py b/test/test_torch.py index f51d9cd405a5a1..4a015829c389a5 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1072,7 +1072,7 @@ def _test_neg(self, cast): if t in float_types: a = cast(torch.randn(100, 90).type(t)) else: - a = cast(torch.Tensor(100, 90).type(t).random_()) + a = cast(torch.Tensor(100, 90).type(t).random_(-128, 128)) zeros = cast(torch.Tensor().type(t)).resize_as_(a).zero_() if t == 'torch.ByteTensor': @@ -6456,17 +6456,6 @@ def test_storage(self): self.assertEqual(v.storage()[0], v.data[0][0]) self.assertEqual(v.storage()[14], v.data[2][4]) - def test_storageview(self): - s1 = torch.LongStorage((3, 4, 5)) - s2 = torch.LongStorage(s1, 1) - - self.assertEqual(s2.size(), 2) - self.assertEqual(s2[0], s1[1]) - self.assertEqual(s2[1], s1[2]) - - s2[1] = 13 - self.assertEqual(13, s1[2]) - def test_nonzero(self): num_src = 12 @@ -6732,14 +6721,14 @@ def test_parsing_intlist(self): def _test_serialization_data(self): a = [torch.randn(5, 5).float() for i in range(2)] - b = [a[i % 2] for i in range(4)] - b += [a[0].storage()] - b += [a[0].storage()[1:4]] - b += [torch.arange(1, 11).int()] - t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,)) - t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,)) - b += [(t1.storage(), t1.storage(), t2.storage())] - b += [a[0].storage()[0:2]] + b = [a[i % 2] for i in range(4)] # 0-3 + b += [a[0].storage()] # 4 + b += [a[0].reshape(-1)[1:4].storage()] # 5 + b += [torch.arange(1, 11).int()] # 6 + t1 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().storage(), 0, (3,), (1,)) + t2 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().storage(), 0, (3,), (1,)) + b += [(t1.storage(), t1.storage(), t2.storage())] # 7 + b += [a[0].reshape(-1)[0:2].storage()] # 8 return b def _test_serialization_assert(self, b, c): @@ -6754,7 +6743,10 @@ def _test_serialization_assert(self, b, c): self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0) c[1].fill_(20) self.assertEqual(c[1], c[3], 0) - self.assertEqual(c[4][1:4], c[5], 0) + # I have to do it in this roundabout fashion, because there's no + # way to slice storages + for i in range(4): + self.assertEqual(c[4][i + 1], c[5][i]) # check that serializing the same storage view object unpickles # it as one object not two (and vice versa) @@ -6914,7 +6906,7 @@ def test_serialization_backwards_compat(self): a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)] b = [a[i % 2] for i in range(4)] b += [a[0].storage()] - b += [a[0].storage()[1:4]] + b += [a[0].reshape(-1)[1:4].clone().storage()] path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt') c = torch.load(path) self.assertEqual(b, c, 0) @@ -6928,7 +6920,6 @@ def test_serialization_backwards_compat(self): self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0) c[1].fill_(20) self.assertEqual(c[1], c[3], 0) - self.assertEqual(c[4][1:4], c[5], 0) # test some old tensor serialization mechanism class OldTensorBase(object): diff --git a/test/test_utils.py b/test/test_utils.py index e8c33ca761c7e8..077f789a0a142a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -630,64 +630,10 @@ def test_bottleneck_cuda(self): class TestCollectEnv(TestCase): - - def _build_env_to_expect(self, 
build_env): - return 'expect/TestCollectEnv.test_{}.expect'.format( - build_env.replace('.', '').replace('-', '_')) - - def _preprocess_info_for_test(self, info_output): - # Remove the version hash - version_hash_regex = re.compile(r'(a\d+)\+\w+') - result = re.sub(version_hash_regex, r'\1', info_output).strip() - - # Substitutions to lower the specificity of the versions listed - substitutions = [ - (r'(?<=CUDA used to build PyTorch: )(\d+)\.(\d+)\.(\d+)', r'\1.\2.X'), - (r'(?<=CUDA runtime version: )(\d+)\.(\d+)\.(\d+)', r'\1.\2.X'), - (r'(?<=Ubuntu )(\d+)\.(\d+)\.(\d+) ', r'\1.\2.X '), - (r'(?<=CMake version: version )(\d+)\.(\d+)\.(\d+)', r'\1.\2.X'), - (r'(?<=Nvidia driver version: )(\d+)\.(\d+)', r'\1.X'), - (r'(?<=Mac OSX )(\d+)\.(\d+).(\d+)', r'\1.\2.X'), - (r'(?<=numpy \()(\d+)\.(\d+).(\d+)', r'\1.\2.X'), - ] - - for regex, substitute in substitutions: - result = re.sub(regex, substitute, result) - return result - - def assertExpectedOutput(self, info_output, build_env): - processed_info = self._preprocess_info_for_test(info_output) - expect_filename = self._build_env_to_expect(build_env) - - ci_warning = ('This test will error out if the CI config was recently ' - 'updated. If this is the case, please update the expect ' - 'files to match the CI machines\' system config.') - - with open(expect_filename, 'r') as f: - expected_info = f.read().strip() - self.assertEqual(ci_warning + '\n' + processed_info, - ci_warning + '\n' + expected_info, ci_warning) - def test_smoke(self): info_output = get_pretty_env_info() self.assertTrue(info_output.count('\n') >= 17) - @unittest.skipIf('BUILD_ENVIRONMENT' not in os.environ.keys(), 'CI-only test') - def test_expect(self): - info_output = get_pretty_env_info() - - ci_build_envs = [ - 'pytorch-linux-trusty-py2.7', - 'pytorch-linux-xenial-cuda9-cudnn7-py3', - 'pytorch-macos-10.13-py3', - 'pytorch-win-ws2016-cuda9-cudnn7-py3' - ] - build_env = os.environ['BUILD_ENVIRONMENT'] - if build_env not in ci_build_envs: - return - - self.assertExpectedOutput(info_output, build_env) - class TestONNXUtils(TestCase): def test_prepare_onnx_paddings(self): diff --git a/third_party/onnx b/third_party/onnx index b4072194c2e6ef..b2817a682f25f9 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit b4072194c2e6ef90693bcfdea4c6f45cf30bb65e +Subproject commit b2817a682f25f960586f06caa539bbbd7a96b859 diff --git a/third_party/onnx-tensorrt b/third_party/onnx-tensorrt index 82106f833dcb00..fa0964e8477fc0 160000 --- a/third_party/onnx-tensorrt +++ b/third_party/onnx-tensorrt @@ -1 +1 @@ -Subproject commit 82106f833dcb0070446a150e658e60ca9428f89b +Subproject commit fa0964e8477fc004ee2f49ee77ffce0bf7f711a9 diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py index 6fc454ca12c074..d6458f9c2337e2 100644 --- a/tools/jit/gen_jit_dispatch.py +++ b/tools/jit/gen_jit_dispatch.py @@ -88,9 +88,10 @@ def attr_of(jit_type): # map from aten 'simple_type' to the function that will turn a tensor into # that type FROM_TENSOR = { - 'Device': 'tensor_as', + 'Device': 'tensor_as>', 'ScalarType': 'tensor_as', 'Layout': 'tensor_as', + 'IntList': 'tensor_as>', } @@ -107,7 +108,7 @@ def from_tensor(arg): """) POS_ASSIGNMENT = CodeTemplate("""\ -auto ${name} = ${from_tensor}(std::move(peek(stack, ${i}, ${N})));\ +auto ${name} = ${from_tensor}(std::move(peek(stack, ${i}, ${N})).toTensor());\ """) CALL_NAMESPACE = CodeTemplate("""\ @@ -261,12 +262,12 @@ def emit_decl_variant(decl, is_positional_arg, has_tensorlist): # NOTE: don't advance 
real_inputs here. After this we are going # to switch over to indexing from the end as if we only had # the static arguments. - arguments.append('peekSlice(stack, {}, varargs_length - {}, varargs_length)' + arguments.append('toTensors(peekSlice(stack, {}, varargs_length - {}, varargs_length))' .format(real_inputs, static_inputs)) elif arg['simple_type'] in default_only_types: arguments.append(arg['default']) elif is_tensor_arg(arg): - arguments.append('std::move(peek(stack, {}, {}))'.format(real_inputs, view_length)) + arguments.append('std::move(peek(stack, {}, {})).toTensor()'.format(real_inputs, view_length)) real_inputs += 1 elif is_positional_arg[i]: template_kwargs = dict(from_tensor=from_tensor(arg), diff --git a/tools/jit/templates/register_aten_ops.cpp b/tools/jit/templates/register_aten_ops.cpp index 4cb7fbaaaaeae2..2f4d0558fe4cb5 100644 --- a/tools/jit/templates/register_aten_ops.cpp +++ b/tools/jit/templates/register_aten_ops.cpp @@ -29,7 +29,6 @@ using autograd::Variable; using autograd::variable_list; using at::Scalar; using at::Tensor; -using at::IntList; using at::TensorList; using at::TensorOptions; using at::DeviceGuard; @@ -39,10 +38,16 @@ namespace { int deviceForInputs(Stack & stack, size_t N) { if(N == 0) return -1; - auto & t = *(stack.end() - N); + auto t = (stack.end() - N)->toTensor(); return t.type().is_cuda() ? (int) t.get_device() : -1; } +std::vector toTensors(at::ArrayRef ivalues) { + return fmap(ivalues, [](const IValue& v) { + return v.toTensor(); + }); +} + template std::array as_bool_array(const std::vector& vec) { std::array res; diff --git a/torch/csrc/PtrWrapper.cpp b/torch/csrc/PtrWrapper.cpp index 544895bfd0a44d..52d7b1f14065df 100644 --- a/torch/csrc/PtrWrapper.cpp +++ b/torch/csrc/PtrWrapper.cpp @@ -45,8 +45,7 @@ static PyObject * THPWrapper_pynew(PyTypeObject *type, PyObject *args, PyObject return self; } -// UBSAN error: https://github.com/pytorch/pytorch/issues/9054 -static void THPWrapper_dealloc(THPWrapper* self) __ubsan_ignore_function__ +static void THPWrapper_dealloc(THPWrapper* self) { self->destructor(self->data); Py_TYPE(self)->tp_free((PyObject*)self); diff --git a/torch/csrc/api/include/torch/nn/cloneable.h b/torch/csrc/api/include/torch/nn/cloneable.h index 43b3b23e822542..61a32e20fe8061 100644 --- a/torch/csrc/api/include/torch/nn/cloneable.h +++ b/torch/csrc/api/include/torch/nn/cloneable.h @@ -26,13 +26,6 @@ class Cloneable : public Module { /// semantics, most importantly parameters, buffers and submodules. virtual void reset() = 0; - /// Moves the `Module` into a `shared_ptr` and calls `reset()` on it. - std::shared_ptr build() { - auto module = std::make_shared(static_cast(*this)); - module->reset(); - return module; - } - /// Performs a recursive "deep copy" of the `Module`, such that all parameters /// and submodules in the cloned module are different from those in the /// original module. @@ -43,12 +36,30 @@ class Cloneable : public Module { copy->buffers_.clear(); copy->children_.clear(); copy->reset(); + AT_CHECK( + copy->parameters_.size() == parameters_.size(), + "The cloned module does not have the same number of " + "parameters as the original module after calling reset(). 
" + "Are you sure you called register_parameter() inside reset() " + "and not the constructor?"); for (const auto& parameter : parameters_) { copy->parameters_[parameter.key].data().copy_(parameter->data()); } + AT_CHECK( + copy->buffers_.size() == buffers_.size(), + "The cloned module does not have the same number of " + "buffers as the original module after calling reset(). " + "Are you sure you called register_buffer() inside reset() " + "and not the constructor?"); for (const auto& buffer : buffers_) { copy->buffers_[buffer.key].data().copy_(buffer->data()); } + AT_CHECK( + copy->children_.size() == children_.size(), + "The cloned module does not have the same number of " + "child modules as the original module after calling reset(). " + "Are you sure you called register_module() inside reset() " + "and not the constructor?"); for (const auto& child : children_) { copy->children_[child.key]->clone_(*child.value); } diff --git a/torch/csrc/api/include/torch/nn/modules/any.h b/torch/csrc/api/include/torch/nn/modules/any.h index 121a8afe0ff925..be5fd3a6702826 100644 --- a/torch/csrc/api/include/torch/nn/modules/any.h +++ b/torch/csrc/api/include/torch/nn/modules/any.h @@ -48,10 +48,14 @@ class AnyModule { AnyModule(AnyModule&&) = default; AnyModule& operator=(AnyModule&&) = default; - /// Creates a copy of an `AnyModule`. + /// Creates a shallow copy of an `AnyModule`. AnyModule(const AnyModule& other); AnyModule& operator=(const AnyModule& other); + /// Creates a deep copy of an `AnyModule` if it contains a module, else an + /// empty `AnyModule` if it is empty. + AnyModule clone() const; + /// Assigns a module to the `AnyModule` (to circumvent the explicit /// constructor). template @@ -238,7 +242,10 @@ struct AnyModule::Placeholder : public AnyModule::Value::Placeholder { /// Returns std::shared_ptr pointing to the erased module. virtual std::shared_ptr ptr() = 0; - /// Returns a `Placeholder` with a copy of this `AnyModule`. + /// Returns a `Placeholder` with a shallow copy of this `AnyModule`. + virtual std::unique_ptr copy() const = 0; + + /// Returns a `Placeholder` with a deep copy of this `AnyModule`. virtual std::unique_ptr clone() const = 0; }; @@ -297,10 +304,15 @@ struct AnyModule::Holder : public AnyModule::Placeholder { return module; } - std::unique_ptr clone() const override { + std::unique_ptr copy() const override { return torch::make_unique(*this); } + std::unique_ptr clone() const override { + return torch::make_unique( + std::static_pointer_cast(module->clone())); + } + /// The actual concrete module instance. std::shared_ptr module; }; @@ -323,15 +335,21 @@ AnyModule::AnyModule(const ModuleHolder& module_holder) : AnyModule(module_holder.ptr()) {} inline AnyModule::AnyModule(const AnyModule& other) - : content_(other.content_ ? other.content_->clone() : nullptr) {} + : content_(other.content_ ? other.content_->copy() : nullptr) {} inline AnyModule& AnyModule::operator=(const AnyModule& other) { if (this != &other) { - content_ = other.content_ ? other.content_->clone() : nullptr; + content_ = other.content_ ? other.content_->copy() : nullptr; } return *this; } +inline AnyModule AnyModule::clone() const { + AnyModule clone; + clone.content_ = content_ ? 
content_->clone() : nullptr; + return clone; +} + template AnyModule& AnyModule::operator=(std::shared_ptr module) { return (*this = AnyModule(std::move(module))); diff --git a/torch/csrc/api/include/torch/nn/modules/sequential.h b/torch/csrc/api/include/torch/nn/modules/sequential.h index 755e712a9f1a21..1c28656692e7f2 100644 --- a/torch/csrc/api/include/torch/nn/modules/sequential.h +++ b/torch/csrc/api/include/torch/nn/modules/sequential.h @@ -35,8 +35,18 @@ class SequentialImpl : public Cloneable { push_back(std::forward(modules)...); } - /// reset() is empty for `Sequential`, since it does not have parameter of its - /// own. + /// Special cloning function for `Sequential` because it does not use + /// `reset()`. + std::shared_ptr clone() const override { + auto clone = std::make_shared(); + for (const auto& module : modules_) { + clone->push_back(module.clone()); + } + return clone; + } + + /// `reset()` is empty for `Sequential`, since it does not have parameter of + /// its own. void reset() override {} /// Feeds the `inputs` to the first module, then chains the output of each diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index a271febf5a3466..16e8105090ecfd 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -136,6 +136,14 @@ void Variable::Impl::set_data(Tensor new_data) { data_ = std::move(new_data); } +void Variable::Impl::release_resources() { + data_.reset(); + grad_.reset(); + grad_fn_.reset(); + hooks_.clear(); + tracing_state_.reset(); +} + Variable::ViewImpl::ViewImpl(Variable base, at::Tensor data, Edge gradient_edge) : Variable::Impl(std::move(data), false, std::move(gradient_edge)), base_(std::move(base)) { @@ -182,6 +190,11 @@ void Variable::ViewImpl::rebase_history(Edge gradient_edge) { get_grad_fn(); // trigger an update to the view's grad_fn } +void Variable::ViewImpl::release_resources() { + Variable::Impl::release_resources(); + base_.reset(); +} + void Variable::rebase_history(Edge gradient_edge) { TORCH_ASSERT(gradient_edge.function != nullptr); if (is_view()) { @@ -200,4 +213,5 @@ void Variable::set_tracing_state( jit::tracer::ValueTracingState& Variable::tracing_state() const noexcept { return *get()->tracing_state_; } + }} // namespace torch::autograd diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index 02e2c043bc0d0c..a6d670ae55f703 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -346,6 +346,9 @@ struct Variable::Impl : public at::TensorImpl { bool keep_graph, bool create_graph) override; + /// Reset all expensive fields to free up resources + void release_resources() override; + // Make this field public so we can access it from `Variable`. using at::TensorImpl::type_; @@ -400,6 +403,9 @@ struct Variable::ViewImpl : public Variable::Impl { return base_; } + /// Reset all expensive fields to free up resources + void release_resources() override; + /// Called after in-place modifications. Modifies the grad_fn of the base /// Variable. 
void rebase_history(Edge gradient_edge); diff --git a/torch/csrc/generic/Storage.cpp b/torch/csrc/generic/Storage.cpp index 4795c4112272b4..3169515784a69d 100644 --- a/torch/csrc/generic/Storage.cpp +++ b/torch/csrc/generic/Storage.cpp @@ -88,44 +88,8 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec // torch.Storage(view_source, [offset, [size]]) if (num_args < 4 && THPStorage_(Check)(first_arg)) { -#ifdef THD_GENERIC_FILE - THPUtils_setError("distributed storages don't support storage views"); + THPUtils_setError("storage views not supported"); return NULL; -#else - THPStorage *storage_arg = (THPStorage *)first_arg; - int64_t numel = storage_arg->cdata->size; - int64_t offset = 0; - - if (num_args >= 2) { - PyObject *second_arg = PyTuple_GET_ITEM(args, 1); - if (!THPUtils_checkLong(second_arg)) - goto invalid_arguments; - offset = THPUtils_unpackLong(second_arg); - } - - int64_t size = numel - offset; - if (num_args >= 3) { - PyObject *third_arg = PyTuple_GET_ITEM(args, 2); - if (!THPUtils_checkLong(third_arg)) - goto invalid_arguments; - size = THPUtils_unpackLong(third_arg); - } - - THPUtils_assert(offset >= 0 && offset <= numel, "specified an offset of " - "%" PRId64 ", but the viewed storage has only %" PRId64 " element(s)", offset, numel); - THPUtils_assert(size >= 1 && size <= numel - offset, "specified a size of " - "%" PRId64 ", but the viewed storage has only %" PRId64 " element(s) after offset %" PRId64, - size, numel - offset, offset); - - real *data_ptr = THWStorage_(data)(LIBRARY_STATE storage_arg->cdata) + offset; - // TODO: Hmmmm - THWStoragePtr storage(THWStorage_(newWithDataAndAllocator)(LIBRARY_STATE {data_ptr, storage_arg->cdata->data_ptr.device()} /* non-owning */, size, nullptr)); - storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW; - storage->view = storage_arg->cdata; - THWStorage_(retain)(LIBRARY_STATE storage_arg->cdata); - self->cdata = storage.release(); - return (PyObject*)self.release(); -#endif } // torch.Storage(sequence) @@ -161,9 +125,6 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec #endif } -#ifndef THD_GENERIC_FILE -invalid_arguments: -#endif THPUtils_invalidArguments(args, kwargs, THPStorageStr " constructor", 6, "no arguments", "(int size)", @@ -199,30 +160,8 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index) return THPUtils_(newReal)(value); /* Slice index */ } else if (PySlice_Check(index)) { -#ifdef THD_GENERIC_FILE - THPUtils_setError("distributed storages don't support slicing"); + THPUtils_setError("storages don't support slicing"); return NULL; -#else - Py_ssize_t start, stop, slicelength, step; - int64_t len = THWStorage_(size)(LIBRARY_STATE self->cdata); - if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength)) - return NULL; - if (step != 1) { - THPUtils_setError("Trying to slice with a step of %" PRId64 ", but only a step of " - "1 is supported", (int64_t)step); - return NULL; - } - - real *data = THWStorage_(data)(LIBRARY_STATE self->cdata); - THWStoragePtr new_storage(THWStorage_(newWithDataAndAllocator)(LIBRARY_STATE {static_cast(data + start), self->cdata->data_ptr.device()} /* non-owning */, slicelength, nullptr)); - new_storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW; - new_storage->view = self->cdata; - THWStorage_(retain)(LIBRARY_STATE self->cdata); - - PyObject *_ret = THPStorage_(New)(new_storage); - new_storage.release(); - return _ret; -#endif } PyErr_Format(PyExc_TypeError, "can't index a " THPStorageStr 
" with %s", THPUtils_typename(index)); diff --git a/torch/csrc/generic/StorageMethods.cpp b/torch/csrc/generic/StorageMethods.cpp index 5ea91375bcb48a..f4e6cd0fd00ee4 100644 --- a/torch/csrc/generic/StorageMethods.cpp +++ b/torch/csrc/generic/StorageMethods.cpp @@ -292,26 +292,6 @@ PyObject * THPStorage_(_setCdata)(THPStorage *self, PyObject *new_cdata) END_HANDLE_TH_ERRORS } -#ifndef THD_GENERIC_FILE -PyObject * THPStorage_(_rootStorage)(THPStorage *self) -{ - HANDLE_TH_ERRORS - if (!(self->cdata->flag & TH_STORAGE_VIEW)) { - return Py_BuildValue("(ON)", self, PyLong_FromLong(0)); - } - THWStorage *root = self->cdata; - while (root->flag & TH_STORAGE_VIEW) - root = root->view; - size_t offset = THWStorage_(data)(LIBRARY_STATE self->cdata) - THWStorage_(data)(LIBRARY_STATE root); - THWStorage_(retain)(LIBRARY_STATE root); - THPObjectPtr storage(THPStorage_(New)(root)); - PyObject *result = Py_BuildValue("(NN)", storage.get(), PyLong_FromLong(offset)); - storage.release(); - return result; - END_HANDLE_TH_ERRORS -} -#endif - static PyMethodDef THPStorage_(methods)[] = { {"copy_", (PyCFunction)THPStorage_(copy_), METH_VARARGS | METH_KEYWORDS, NULL}, {"element_size", (PyCFunction)THPStorage_(elementSize), METH_NOARGS, NULL}, @@ -335,7 +315,6 @@ static PyMethodDef THPStorage_(methods)[] = { #endif {"_set_cdata", (PyCFunction)THPStorage_(_setCdata), METH_O, NULL}, #ifndef THD_GENERIC_FILE - {"_root_storage", (PyCFunction)THPStorage_(_rootStorage), METH_NOARGS, NULL}, #endif {NULL} }; diff --git a/torch/csrc/generic/StorageSharing.cpp b/torch/csrc/generic/StorageSharing.cpp index b9b002d1ff334e..c68dfbfa049b0c 100644 --- a/torch/csrc/generic/StorageSharing.cpp +++ b/torch/csrc/generic/StorageSharing.cpp @@ -34,16 +34,6 @@ static PyObject * THPStorage_(sharedIncref)(THPStorage *self) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(newTHView)(THWStorage *base, ptrdiff_t offset, size_t size) -{ - void *data = (char*)base->data() + offset; - THWStoragePtr view(THWStorage_(newWithDataAndAllocator)(LIBRARY_STATE {data, base->data_ptr.device()} /* non-owning */, size, nullptr)); - view->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW; - view->view = base; - THWStorage_(retain)(LIBRARY_STATE base); - return THPStorage_(New)(view.release()); -} - #ifndef THC_GENERIC_FILE // TODO: move this somewhere - we only need one version static std::string THPStorage_(__newHandle)() { @@ -226,13 +216,12 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self) HANDLE_TH_ERRORS THWStorage *storage = self->cdata; at::DeviceGuard device_guard(storage->data_ptr.device().index()); - THPObjectPtr tuple(PyTuple_New(5)); + THPObjectPtr tuple(PyTuple_New(4)); THPObjectPtr device(PyLong_FromLong(storage->data_ptr.device().index())); THPObjectPtr _handle(Py_None); Py_INCREF(Py_None); THPObjectPtr size(PyLong_FromLong(storage->size)); THPObjectPtr _offset(PyLong_FromLong(0)); - THPObjectPtr view_size(PyLong_FromLong(storage->size)); if (THWStorage_(data)(LIBRARY_STATE storage)) { size_t base_size; void *base_ptr = THCCachingAllocator_getBaseAllocation(THWStorage_(data)(LIBRARY_STATE storage), &base_size); @@ -242,17 +231,16 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self) THCudaCheck(cudaIpcGetMemHandle(&handle, base_ptr)); _handle = PyBytes_FromStringAndSize((char *)&handle, CUDA_IPC_HANDLE_SIZE); - _offset = PyLong_FromSsize_t((Py_ssize_t)offset); + _offset = PyLong_FromSsize_t((Py_ssize_t)offset / sizeof(real)); size = PyLong_FromSize_t(base_size / sizeof(real)); } - if (!tuple || !device || !_handle || 
!size || !_offset || !view_size) { + if (!tuple || !device || !_handle || !size || !_offset) { return NULL; } PyTuple_SET_ITEM(tuple.get(), 0, device.release()); PyTuple_SET_ITEM(tuple.get(), 1, _handle.release()); PyTuple_SET_ITEM(tuple.get(), 2, size.release()); PyTuple_SET_ITEM(tuple.get(), 3, _offset.release()); - PyTuple_SET_ITEM(tuple.get(), 4, view_size.release()); return tuple.release(); END_HANDLE_TH_ERRORS } @@ -260,23 +248,18 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self) static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args) { HANDLE_TH_ERRORS - THPUtils_assert(PyTuple_GET_SIZE(args) == 5, "tuple of 5 items expected"); + THPUtils_assert(PyTuple_GET_SIZE(args) == 3, "tuple of 3 items expected"); PyObject *_device = PyTuple_GET_ITEM(args, 0); PyObject *_handle = PyTuple_GET_ITEM(args, 1); PyObject *_size = PyTuple_GET_ITEM(args, 2); - PyObject *_offset = PyTuple_GET_ITEM(args, 3); - PyObject *_view_size = PyTuple_GET_ITEM(args, 4); if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size) - && (_handle == Py_None || PyBytes_Check(_handle)) - && THPUtils_checkLong(_offset) && THPUtils_checkLong(_view_size))) { + && (_handle == Py_None || PyBytes_Check(_handle)))) { THPUtils_invalidArguments(args, NULL, "_new_shared in CUDA mode", 1, - "(int device, bytes handle, int storage_size, int offset, int view_size"); + "(int device, bytes handle, int storage_size)"); return NULL; } size_t storage_size = (size_t)THPUtils_unpackLong(_size); - ptrdiff_t offset = (ptrdiff_t)THPUtils_unpackLong(_offset); - size_t view_size = (size_t)THPUtils_unpackLong(_view_size); int64_t device = THPUtils_unpackLong(_device); at::DeviceGuard device_guard(device); @@ -296,11 +279,7 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args) LIBRARY_STATE THCIpcDeleter::makeDataPtr(devPtr, device), storage_size, /* allocator */ nullptr)); - base->flag = TH_STORAGE_REFCOUNTED; - - if (offset != 0 || view_size != storage_size) { - return THPStorage_(newTHView)(base.get(), offset, view_size); - } + base->flag = TH_STORAGE_REFCOUNTED; // NB: Not resizable return THPStorage_(New)(base.release()); END_HANDLE_TH_ERRORS @@ -315,9 +294,6 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args) static PyObject * THPStorage_(weakRef)(THPStorage *self, PyObject *weak_ref_class) { HANDLE_TH_ERRORS THStorage* storage = self->cdata; - while (storage->flag & TH_STORAGE_VIEW) { - storage = storage->view; - } THStorage_weakRetain(storage); @@ -387,20 +363,6 @@ PyObject * THPStorage_(sharedFd)(THPStorage *self) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(newView)(THPStorage *self, PyObject *args) -{ - HANDLE_TH_ERRORS - if (PyTuple_Size(args) != 2 || !THPUtils_checkLong(PyTuple_GET_ITEM(args, 0)) - || ! 
THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) { - THPUtils_invalidArguments(args, NULL, "_new_view", 1, "(int offset, int size)"); - return NULL; - } - int64_t offset = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0)); - int64_t size = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1)); - return THPStorage_(newTHView)(self->cdata, offset, size); - END_HANDLE_TH_ERRORS -} - PyObject * THPStorage_(isShared)(THPStorage *self) { #ifdef THC_GENERIC_FILE @@ -430,7 +392,6 @@ static PyMethodDef THPStorage_(sharingMethods)[] = { #endif {"_weak_ref", (PyCFunction)THPStorage_(weakRef), METH_O, NULL}, {"_free_weak_ref", (PyCFunction)THPStorage_(freeWeakRef), METH_O | METH_STATIC, NULL}, - {"_new_view", (PyCFunction)THPStorage_(newView), METH_VARARGS, NULL}, {"_shared_decref", (PyCFunction)THPStorage_(sharedDecref), METH_NOARGS, NULL}, {"_shared_incref", (PyCFunction)THPStorage_(sharedIncref), METH_NOARGS, NULL}, {"_get_shared_fd", (PyCFunction)THPStorage_(sharedFd), METH_NOARGS, NULL}, diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index da1167739a9d63..fdeb0ef13a8c36 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -17,7 +17,7 @@ bool hasOneValuedAttribute(Node *n, torch::jit::Symbol name) { bool isDifferentiable(Node * n) { static std::unordered_set differentiable_kinds = { - aten::add, aten::sub, aten::mul, prim::Constant, prim::ReplaceIfUndef, + aten::add, aten::sub, aten::mul, prim::Constant, aten::sigmoid, aten::tanh, aten::mm, aten::chunk, aten::split, aten::t, aten::neg, aten::unsqueeze, aten::expand, aten::addmm, aten::gt, aten::lt, aten::eq, aten::ne, aten::ge, aten::le, aten::type_as, aten::relu, aten::exp, prim::AutogradAdd @@ -99,8 +99,6 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val return {grads.at(0) * inputs.at(1), grads.at(0) * inputs.at(0)}; case prim::Constant: return {}; - case prim::ReplaceIfUndef: - return {grads.at(0), grads.at(0)}; case aten::sigmoid: return {grads.at(0) * outputs.at(0) * (1 - outputs.at(0))}; case aten::tanh: diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index 65c6a7086b9fd9..5ef60d95a47dc5 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -21,6 +21,7 @@ #include "torch/csrc/jit/passes/loop_unrolling.h" #include "torch/csrc/jit/passes/lower_grad_of.h" #include "torch/csrc/jit/symbolic_variable.h" +#include "torch/csrc/jit/ivalue.h" #include "torch/csrc/autograd/edge.h" #include "torch/csrc/autograd/function.h" @@ -72,6 +73,16 @@ struct ExecutionPlanAutogradFunction : public autograd::Function { }; +// helper to run interpreter on variables until we switch +// everything to IValue +inline variable_tensor_list runOneStage(const Code & code, variable_tensor_list inputs) { + std::vector stack(inputs.begin(), inputs.end()); + InterpreterState(code).runOneStage(stack); + return variable_tensor_list(fmap(stack, [](IValue& v) { + return std::move(v).toTensor(); + })); +} + // an optimized way of executing the subgraph computed directly on // tensors rather than Variables. // This will unwrap Variables, run the plan, and re-wrap them. 
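Aside (not part of the patch): the new runOneStage helper above wraps the incoming tensors into the interpreter's generic stack type, runs one stage, and unwraps the results back into a tensor list. Below is a minimal, self-contained sketch of that wrap/run/unwrap pattern; Tensor, Value, Stack, and run_stage are simplified stand-ins (a double, a std::variant, and a toy stage body), not the real at::Tensor / IValue / InterpreterState API.

// Illustrative sketch (C++17): the wrap -> run -> unwrap shape of runOneStage,
// using stand-in types rather than the real interpreter.
#include <cstdint>
#include <iostream>
#include <variant>
#include <vector>

using Tensor = double;                        // stand-in for at::Tensor
using Value  = std::variant<Tensor, int64_t>; // stand-in for torch::jit::IValue
using Stack  = std::vector<Value>;

// stand-in for InterpreterState::runOneStage: doubles every tensor on the stack
void run_stage(Stack& stack) {
  for (auto& v : stack) v = std::get<Tensor>(v) * 2.0;
}

std::vector<Tensor> run_one_stage(std::vector<Tensor> inputs) {
  Stack stack(inputs.begin(), inputs.end());  // wrap: tensors -> generic stack values
  run_stage(stack);
  std::vector<Tensor> outputs;                // unwrap: stack values -> tensors again
  outputs.reserve(stack.size());
  for (auto& v : stack) outputs.push_back(std::get<Tensor>(std::move(v)));
  return outputs;
}

int main() {
  for (Tensor t : run_one_stage({1.0, 2.0, 3.0})) std::cout << t << ' ';
  std::cout << '\n';  // prints: 2 4 6
}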
@@ -90,8 +101,7 @@ struct ExecutionPlan { if(grad) { return runWithGrad(std::move(stack)); } - InterpreterState(f).runOneStage(stack); - return stack; + return runOneStage(f, std::move(stack)); } std::shared_ptr get_graph() const { return graph; @@ -113,14 +123,15 @@ struct ExecutionPlan { } private: - // inplace to avoid allocations - variable_tensor_list unwrapVariables(variable_tensor_list && list) const { - for(auto & v : list) { - v = v.defined() ? autograd::as_variable_ref(v).detach() : at::Tensor(); - } - return std::move(list); + // note: should be inplace to avoid allocations, but we have to switch from + // a list of tensor to a list of ivalues + std::vector unwrapVariables(variable_tensor_list && list) const { + return fmap(list, [](const Variable& v) -> IValue { + return v.defined() ? autograd::as_variable_ref(v).detach() : at::Tensor(); + }); } - // inplace to avoid allocations + // note: should be inplace to avoid allocations, but we have to switch from + // a list of tensor to a list of ivalues variable_tensor_list wrapTensors(tensor_list && list) const { for(auto & v : list) { v = autograd::make_variable(v, /*requires_grad=*/false); @@ -152,7 +163,8 @@ struct ExecutionPlan { auto stack = unwrapVariables(std::move(inputs)); InterpreterState(f).runOneStage(stack); - variable_tensor_list outputs = std::move(stack); + variable_tensor_list outputs( + fmap(stack, [](IValue& v) { return std::move(v).toTensor(); })); // hookup the gradients for the output tensors that require gradients // to the inputs to our gradient function df @@ -311,11 +323,7 @@ struct GraphExecutorImpl { variable_tensor_list runFallback(variable_tensor_list inputs) { auto & fb = getOrCreateAutogradFallback(); - InterpreterState state(fb); - auto stack = std::move(inputs); - state.runOneStage(stack); - // note: we never unwrapped inputs, because we want autograd to record the trace - return stack; + return runOneStage(fb, std::move(inputs)); } static bool calcMayIntroduceGradient(Block* b) { diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h index a4a73eb8f24846..b61a49be846dc7 100644 --- a/torch/csrc/jit/interned_strings.h +++ b/torch/csrc/jit/interned_strings.h @@ -35,7 +35,6 @@ _(prim, PadPacked) /* onnx */ \ _(prim, Placeholder) /* debug */ \ _(prim, Print) \ _(prim, PythonOp) \ -_(prim, ReplaceIfUndef) \ _(prim, Reverse) \ _(prim, Return) \ _(prim, Store) \ diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 1fb82c9035952e..1dd6ea6c5877cc 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -9,6 +9,7 @@ #include "torch/csrc/jit/graph_executor.h" #include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/tensor_conversions.h" +#include "torch/csrc/jit/ivalue.h" #include "torch/csrc/variable_tensor_functions.h" #include "torch/csrc/autograd/generated/variable_factories.h" @@ -410,7 +411,7 @@ struct CodeImpl { JIT_ASSERT(inst.debug_name == prim::Placeholder); auto offset = relativeJump(from_inst, to_inst); inst.callback = [offset](Stack & stack) { - auto t = tensor_as(pop(stack)); + auto t = tensor_as(pop(stack).toTensor()); return (t == 0) ? offset : 0; }; inst.debug_name = prim::JumpZ; @@ -422,7 +423,7 @@ struct CodeImpl { JIT_ASSERT(inst.debug_name == prim::Placeholder); auto offset = relativeJump(from_inst, to_inst); inst.callback = [offset](Stack & stack) { - auto t = tensor_as(pop(stack)); + auto t = tensor_as(pop(stack).toTensor()); return (t != 0) ? 
offset : 0; }; inst.debug_name = prim::JumpNZ; @@ -629,7 +630,8 @@ struct CodeImpl { return [=](Stack& stack) mutable { autograd::profiler::RecordFunction record("GraphExecutor"); auto inputs = last(stack, num_inputs); - variable_tensor_list tinputs(inputs.begin(), inputs.end()); + variable_tensor_list tinputs( + fmap(inputs, [](const IValue& v) { return v.toTensor(); })); drop(stack, num_inputs); //TODO: has graph executor work from a stack as well variable_tensor_list toutputs = executor->run(variable_tensor_list(std::move(tinputs))); @@ -774,7 +776,7 @@ struct InterpreterStateImpl { // in the case where it is true, then the interpreter and this array get copied // if this every becomes a bottleneck then we _should_ consider minimizing the // total number or register - std::vector registers; + std::vector registers; // single buffer for input/output calls to ATen functions, so that we do not reallocate Stack stack; @@ -799,7 +801,7 @@ InterpreterState::InterpreterState(const Code & function) InterpreterState::~InterpreterState() {} void InterpreterState::runOneStage(Stack & stack) { - return pImpl->runOneStage(stack); + return pImpl->runOneStage(stack); } const TensorType & InterpreterState::tensorTypeForInput(size_t i) const { diff --git a/torch/csrc/jit/interpreter.h b/torch/csrc/jit/interpreter.h index ed086bd05f881f..b9085528fbb44d 100644 --- a/torch/csrc/jit/interpreter.h +++ b/torch/csrc/jit/interpreter.h @@ -19,6 +19,8 @@ struct InterpreterStateImpl; struct Graph; struct Node; struct TensorType; +struct IValue; +using Stack = std::vector; struct Code { Code() @@ -44,7 +46,7 @@ struct InterpreterState { // advance the interpreter state by running one stage. Returning the // outputs for that stage, suspending the computation. // Call this function again continues computation where it left off. 
- void runOneStage(std::vector & stack); + void runOneStage(Stack & stack); const TensorType & tensorTypeForInput(size_t i) const; ~InterpreterState(); // create a copy of InterpreterState with its current state diff --git a/torch/csrc/jit/ivalue.h b/torch/csrc/jit/ivalue.h new file mode 100644 index 00000000000000..c31436de5ab105 --- /dev/null +++ b/torch/csrc/jit/ivalue.h @@ -0,0 +1,278 @@ +#pragma once +#include +#include "torch/csrc/assertions.h" + +namespace torch { namespace jit { + +// smart pointer to hold onto at::Retainable objects in a generic way +// this is close to the implementation of boost's intrusive_ptr +template +struct Shared { + Shared(): Shared(nullptr, false) {} + Shared(PointerType * self, bool retain) + : pImpl(self) { + if(retain && pImpl) + pImpl->retain(); + } + Shared(const Shared & rhs) + : pImpl(rhs.pImpl) { + if (pImpl) + pImpl->retain(); + } + Shared(Shared && rhs) noexcept + : pImpl(rhs.pImpl) { + rhs.pImpl = nullptr; + } + ~Shared() { + if (pImpl) + pImpl->release(); + } + Shared & operator=(Shared && rhs) & { + rhs.swap(*this); + return *this; + } + Shared & operator=(Shared const & rhs) & { + //Shared ctor retains original rhs.pImpl + //then rhs.pImpl is swapped with this->pImpl + //finally Shared dtor releases rhs.pImpl, which was originally this->pImpl + Shared(rhs).swap(*this); + return *this; + } + void reset() { + Shared().swap(*this); + } + void reset(PointerType * rhs) { + Shared(rhs, true).swap(*this); + } + void reset(PointerType * rhs, bool retain) { + Shared(rhs, retain).swap(*this); + } + void swap(Shared & rhs) { + PointerType * tmp = pImpl; + pImpl = rhs.pImpl; + rhs.pImpl = tmp; + } + PointerType* get() const { + return pImpl; + } + PointerType* detach() { + PointerType * ret = pImpl; + pImpl = nullptr; + return ret; + } + PointerType& operator*() const { + return *get(); + } + PointerType* operator->() const { + return get(); + } + operator bool() const { + return pImpl != nullptr; + } +private: + PointerType * pImpl; +}; + + +template +struct ConstantList; +struct IValue; +using Tuple = ConstantList; +using IntList = ConstantList; +using DoubleList = ConstantList; + +// IValue is the generic tagged union used by the interpreter to hold +// all value types. +// It is a 16-byte object with an 8-byte payload and an 8-byte tag. +// The tag is currently 4 bytes to determine the type, and 1 byte +// to mark whether that type is a subtype of at::Retainable and needs +// retain/release calls. +struct IValue { + IValue() + : payload(0) + , tag(Tag::None) + , retainable(false) {} + IValue(const IValue& rhs) + : payload(rhs.payload), + tag(rhs.tag), + retainable(rhs.retainable) { + if (retainable) + as_retainable->retain(); + } + IValue(IValue&& rhs) noexcept : IValue() { + swap(rhs); + } + ~IValue() { + if (retainable) { + as_retainable->release(); + } + } + IValue & operator=(IValue && rhs) & { + rhs.swap(*this); + return *this; + } + IValue & operator=(IValue const & rhs) & { + IValue(rhs).swap(*this); + return *this; + } + void swap(IValue & rhs) { + std::swap(payload, rhs.payload); + std::swap(retainable, rhs.retainable); + std::swap(tag, rhs.tag); + } + // Accessors for subtypes are arragned together below + // While some of these accessors could be generated through templates, + // we prefer to write them manually for clarity + + // Tensor + IValue(at::Tensor t) + : tag(Tag::Tensor), retainable(t.defined()) { + // note: the undefined tensor is not refcounted, so while it + // is tagged as a tensor, retainable is set to false. 
+ as_tensor_impl = t.at::detail::TensorBase::detach(); + } + bool isTensor() const { return Tag::Tensor == tag; } + at::Tensor toTensor() && { + JIT_ASSERT(isTensor()); + at::Tensor t(as_tensor_impl, /*retain=*/false); + clearToNone(); + return t; + } + at::Tensor toTensor() const & { + JIT_ASSERT(isTensor()); + return at::Tensor(as_tensor_impl, /*retain=*/true); + } + + // Tuple + IValue(Shared v); + bool isTuple() const { return Tag::Tuple == tag; } + Shared toTuple() && { + JIT_ASSERT(isTuple()); + return moveToRetainable(); + } + Shared toTuple() const & { + JIT_ASSERT(isTuple()); + return toRetainable(); + } + + // Double + IValue(double d) + : tag(Tag::Double), retainable(false) { + as_double = d; + } + bool isDouble() const { return Tag::Double == tag; } + double toDouble() const { + JIT_ASSERT(isDouble()); + return as_double; + } + + // Int + IValue(int64_t i) + : tag(Tag::Int), retainable(false) { + as_int = i; + } + // allow you to pass literals (3, 4) without ambiguity + IValue(int32_t i) + : IValue(static_cast(i)) {} + + bool isInt() const { return Tag::Int == tag; } + int64_t toInt() const { + JIT_ASSERT(isInt()); + return as_int; + } + + // IntList + IValue(Shared v); + bool isIntList() const { return Tag::IntList == tag; } + Shared toIntList() && { + JIT_ASSERT(isIntList()); + return moveToRetainable(); + } + Shared toIntList() const & { + JIT_ASSERT(isIntList()); + return toRetainable(); + } + + // DoubleList + IValue(Shared v); + bool isDoubleList() const { return Tag::DoubleList == tag; } + Shared toDoubleList() && { + JIT_ASSERT(isDoubleList()); + return moveToRetainable(); + } + Shared toDoubleList() const & { + JIT_ASSERT(isDoubleList()); + return toRetainable(); + } + + bool isNone() { + return Tag::None == tag; + } + +private: + template + Shared moveToRetainable() { + Shared t(static_cast(as_retainable), false); + clearToNone(); + return t; + } + template + Shared toRetainable() const { + return Shared(static_cast(as_retainable), true); + } + void clearToNone() { + payload = 0; + tag = Tag::None; + retainable = false; + } + enum class Tag : uint32_t { + None, Tensor, Double, Int, Tuple, IntList, DoubleList + }; + union { + at::TensorImpl* as_tensor_impl; + at::Retainable* as_retainable; + double as_double; + int64_t as_int; + // this type should be as big as all the other types because it will + // be used to copy the union's value in certain cases + int64_t payload; + }; + Tag tag; + bool retainable; +}; + + +// non-mutable list +template +struct ConstantList : at::Retainable { + private: + ConstantList(std::vector elements_) + : elements_(std::move(elements_)) {} + std::vector elements_; + public: + static Shared> create(std::vector elements_) { + return Shared>( + new ConstantList(std::move(elements_)), false); + } + at::ArrayRef elements() const { + return elements_; + } +}; + +inline IValue::IValue(Shared v) +: tag(Tag::Tuple), retainable(true) { + as_retainable = v.detach(); +} + +inline IValue::IValue(Shared v) +: tag(Tag::IntList), retainable(true) { + as_retainable = v.detach(); +} + +inline IValue::IValue(Shared v) +: tag(Tag::DoubleList), retainable(true) { + as_retainable = v.detach(); +} + + +}} diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index 5619a5ad08cbfb..4d997bb8017a08 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -41,6 +41,21 @@ void PeepholeOptimize(Block * block) { // Let DCE clean up any unused nodes at this point } } break; + case aten::type_as: { + 
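// ---------------------------------------------------------------------------
// Aside (illustrative sketch, not part of the patch): the Shared<T>/IValue
// pair introduced in ivalue.h above follows the usual intrusive-refcount
// discipline: copying a handle retains, destroying it releases, moving steals
// the pointer without touching the count, and the rvalue accessors
// (toTensor() &&, toTuple() && ...) hand ownership to the caller and reset the
// value to None. A minimal self-contained illustration of that ownership
// pattern, using hypothetical names (Counted, Handle) rather than the real
// classes:
#include <atomic>
#include <cassert>
#include <cstdint>
#include <utility>

struct Counted {
  std::atomic<uint32_t> refcount{1};          // starts owned by its creator
  void retain()  { ++refcount; }
  void release() { if (--refcount == 0) delete this; }
  virtual ~Counted() {}
};

struct Handle {
  explicit Handle(Counted* p) : p_(p) {}                       // adopts the initial count
  Handle(const Handle& o) : p_(o.p_) { if (p_) p_->retain(); } // copy retains
  Handle(Handle&& o) noexcept : p_(o.p_) { o.p_ = nullptr; }   // move steals, no refcount traffic
  ~Handle() { if (p_) p_->release(); }
  Counted* get() const { return p_; }
 private:
  Counted* p_;
};

int main() {
  Handle a(new Counted());                           // refcount == 1
  { Handle b = a; assert(a.get()->refcount == 2); }  // copy retains; scope exit releases
  Handle c = std::move(a);                           // move: refcount still 1, a is now empty
  assert(c.get()->refcount == 1 && a.get() == nullptr);
}
// ---------------------------------------------------------------------------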
JIT_ASSERT(n->inputs().size() == 2); + Value *lhs = n->input(0); + Value *rhs = n->input(1); + // If LHS and RHS have the same static type, remove the type_as operator. + if (lhs->type()->kind() == TypeKind::TensorType && + rhs->type()->kind() == TypeKind::TensorType) { + auto ltype = (*lhs->type()).cast(); + auto rtype = (*rhs->type()).cast(); + if(ltype->device() == rtype->device() && + ltype->scalarType() == rtype->scalarType()) { + n->output()->replaceAllUsesWith(lhs); + } + } + } break; // Fuse mm + add into addmm case aten::add: { // Must have two inputs diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 7e4b45e986eeb2..f8239c5a6457c7 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -88,7 +88,7 @@ void broadcastPointwise(Node *node, std::vector& types) { void PropagateShapeOnNodeByRunningIt(Node* node, const std::vector& types) { auto op = getOperation(node); - std::vector stack; + Stack stack; for(auto & type : types) { stack.push_back(representativeTensor(type)); @@ -102,7 +102,7 @@ void PropagateShapeOnNodeByRunningIt(Node* node, const std::vector& JIT_ASSERT(stack.size() == node->outputs().size()); for(size_t i = 0; i < stack.size(); ++i) { - node->outputs()[i]->inferTypeFrom(stack[i]); + node->outputs()[i]->inferTypeFrom(stack[i].toTensor()); } } @@ -322,14 +322,6 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { node->output()->setType(ten->withSizes(sizes)); } } break; - case prim::ReplaceIfUndef: { - // If types[0] has a type, then it is not defined, and the type will - // get set to types[0] because that will be the value propagated. - // If its type is not defined, then unification is an undefined type. - SHAPE_ASSERT(types.size() == 1); - node->output()->setType(types.at(0)->shared_from_this()); - handled = true; - } break; case prim::Constant: { node->output()->inferTypeFrom(node->t(attr::value)); handled = true; diff --git a/torch/csrc/jit/python_interpreter.cpp b/torch/csrc/jit/python_interpreter.cpp index c0668b7a6e2bd3..5af53c4455b12f 100644 --- a/torch/csrc/jit/python_interpreter.cpp +++ b/torch/csrc/jit/python_interpreter.cpp @@ -44,7 +44,7 @@ Operation createPythonOperation(Node* op_) { py_inputs[i] = py::reinterpret_borrow( op->scalar_args[next_scalar++].get()); } else if (arg_type == 't') { - auto var = peek(stack, next_tensor, num_inputs); + auto var = std::move(peek(stack, next_tensor, num_inputs)).toTensor(); py_inputs[i] = py::reinterpret_steal(THPVariable_Wrap(var)); next_tensor++; diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 0d084edefa52d4..3a2ae20850de80 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -42,7 +42,10 @@ RegisterOperators reg({ autograd::profiler::RecordFunction record("FusionGroup"); std::vector toutputs; // TODO: have fusion_fn work off of a stack as well - fusion_fn->launch(last(stack, num_inputs), toutputs); + auto tinputs = fmap(last(stack, num_inputs), [](const IValue& v) { + return v.toTensor(); + }); + fusion_fn->launch(tinputs, toutputs); drop(stack, num_inputs); stack.insert(stack.end(), toutputs.begin(), toutputs.end()); return 0; @@ -69,28 +72,14 @@ RegisterOperators reg({ return 0; }; }), - Operator( - prim::ReplaceIfUndef, - [](Node* n) { - return [](Stack& stack) { - auto alternate = pop(stack); - auto result = pop(stack); - if (result.defined()) { - stack.push_back(std::move(result)); - } else { - 
stack.push_back(std::move(alternate)); - } - return 0; - }; - }), - Operator( prim::Print, [](Node* node) { size_t num_inputs = node->inputs().size(); return [num_inputs](Stack& stack) { bool first = true; - for (at::Tensor i : last(stack, num_inputs)) { + for (const IValue& i_ : last(stack, num_inputs)) { + auto i = i_.toTensor(); if (!first) std::cout << " "; first = false; @@ -114,7 +103,7 @@ RegisterOperators reg({ // and inst.outputs Operator(prim::Load, noop), // x, y = Store - // stores values from stack into registers, the actual callback does + // stores vales from stack into registers, the actual callback does // nothing since the stack manipulation is already encoded in inst.inputs // and inst.outputs Operator(prim::Store, noop), @@ -132,8 +121,8 @@ RegisterOperators reg({ onnx::Reshape, [](Node* node) { return [=](Stack& stack) { - auto shape = pop(stack).contiguous(); - auto input = pop(stack); + auto shape = pop(stack).toTensor().contiguous(); + auto input = pop(stack).toTensor(); JIT_ASSERT(shape.ndimension() == 1); at::IntList shape_list(shape.data(), shape.size(0)); stack.push_back(input.reshape(shape_list)); @@ -144,7 +133,7 @@ RegisterOperators reg({ onnx::Shape, [](Node* node) { return [=](Stack& stack) { - auto t = pop(stack); + auto t = pop(stack).toTensor(); at::IntList sizes = t.sizes(); auto sizes_tensor = torch::empty( {static_cast(sizes.size())}, at::dtype(at::kLong)); @@ -165,8 +154,8 @@ RegisterOperators reg({ auto false_ = at::full({}, 0, at::kLong); return [=](Stack& stack) { bool result = false; - for (const at::Tensor& t : last(stack, num_inputs)) { - if (t.defined()) { + for (const IValue& t : last(stack, num_inputs)) { + if (std::move(t).toTensor().defined()) { result = true; break; } @@ -181,8 +170,8 @@ RegisterOperators reg({ prim::AutogradAdd, [](Node* node) { return [=](Stack& stack) { - auto a = pop(stack); - auto b = pop(stack); + auto a = pop(stack).toTensor(); + auto b = pop(stack).toTensor(); if (!a.defined()) stack.push_back(b); else if (!b.defined()) diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 8d1d3f754515fd..df3ff8151c6c1d 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -297,9 +297,9 @@ static bool isNumberSubtype(const TypePtr& type) { } at::optional> getIntListAttribute(at::optional N, Value* input) { - auto list = constant_as(input); + auto list = constant_as>(input); if(list) - return std::vector(*list); + return list; // broadcast IntList[3] with value 4 -> {4, 4, 4} if(!N) return at::nullopt; diff --git a/torch/csrc/jit/stack.h b/torch/csrc/jit/stack.h index 503725396f086e..e4d1d185db5be4 100644 --- a/torch/csrc/jit/stack.h +++ b/torch/csrc/jit/stack.h @@ -1,10 +1,11 @@ #pragma once #include "ATen/ATen.h" #include "torch/csrc/jit/tensor_conversions.h" +#include "torch/csrc/jit/ivalue.h" namespace torch { namespace jit { -using Stack = std::vector; +using Stack = std::vector; using Operation = std::function; // An operation with N inputs and M outputs pops the last N inputs off @@ -21,21 +22,21 @@ using Operation = std::function; // treat the last N elements of the stack as a list, looking up // element i -static inline at::Tensor & peek(Stack & stack, size_t i, size_t N) { +static inline IValue & peek(Stack & stack, size_t i, size_t N) { return *(stack.end() - N + i); } // treat the last N elements of the stack as a list, looking up the // slice starting at index i and having length len -static inline at::ArrayRef peekSlice(Stack & stack, size_t i, 
size_t len, size_t N) { - return at::ArrayRef(stack).slice(stack.size() - N + i, len); +static inline at::ArrayRef peekSlice(Stack & stack, size_t i, size_t len, size_t N) { + return at::ArrayRef(stack).slice(stack.size() - N + i, len); } -static inline at::ArrayRef last(Stack & stack, size_t N) { +static inline at::ArrayRef last(Stack & stack, size_t N) { return peekSlice(stack, 0, N, N); } static inline void drop(Stack & stack, size_t n) { stack.erase(stack.end() - n, stack.end()); } -static inline at::Tensor pop(Stack & stack) { +static inline IValue pop(Stack & stack) { auto r = std::move(stack.back()); stack.pop_back(); return r; @@ -47,22 +48,22 @@ static inline at::Tensor pop(Stack & stack) { // pack takes the return values of aten functions pushes them onto the stack template inline void pack(Stack & stack, T&& v) { - stack.push_back(as_variable(std::move(v))); + stack.push_back(IValue(as_variable(std::move(v)))); } template<> inline void pack(Stack & stack, at::Tensor&& v) { - stack.push_back(std::move(v)); + stack.push_back(IValue(std::move(v))); } template<> inline void pack(Stack & stack, autograd::Variable&& v) { - stack.push_back(std::move(v)); + stack.push_back(IValue(std::move(v))); } template<> inline void pack(Stack & stack, std::vector&& ts) { for(auto& t : ts) { - stack.push_back(std::move(t)); + stack.push_back(IValue(std::move(t))); } } diff --git a/torch/csrc/jit/tensor_conversions.h b/torch/csrc/jit/tensor_conversions.h index 276db961495dd5..84162a445a910f 100644 --- a/torch/csrc/jit/tensor_conversions.h +++ b/torch/csrc/jit/tensor_conversions.h @@ -57,15 +57,15 @@ struct tensor_as_impl> { }; template<> -struct tensor_as_impl { - at::IntList operator()(at::Tensor&& t) { +struct tensor_as_impl> { + std::vector operator()(at::Tensor&& t) { if (t.type().scalarType() != at::ScalarType::Long) throw tensor_conversion_error("Expected a LongTensor"); if (t.dim() != 1) throw tensor_conversion_error("Expected a 1D LongTensor"); if (!t.is_contiguous()) throw tensor_conversion_error("Expected a contiguous LongTensor"); - return at::IntList{t.data(), static_cast(t.numel())}; + return std::vector(t.data(), t.data() + t.numel()); } }; diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp index 54e99f98e4648e..7b784f093330ba 100644 --- a/torch/csrc/jit/test_jit.cpp +++ b/torch/csrc/jit/test_jit.cpp @@ -36,6 +36,8 @@ #include "torch/csrc/jit/graph_executor.h" #include "torch/csrc/jit/script/compiler.h" #include "torch/csrc/jit/script/module.h" +#include "torch/csrc/jit/ivalue.h" + #include "onnx/onnx_pb.h" @@ -439,8 +441,12 @@ std::shared_ptr build_lstm_stages() { } void runOneStage(InterpreterState & interp, const std::vector & inputs, std::vector & outputs) { - outputs = inputs; - interp.runOneStage(outputs); + std::vector stack(inputs.begin(), inputs.end()); + interp.runOneStage(stack); + outputs.clear(); + for(auto & ivalue : stack) { + outputs.push_back(std::move(ivalue).toTensor()); + } } void interpTest() { @@ -878,7 +884,7 @@ const static auto cf_examples = R"JIT( void testControlFlow() { script::Module cu; script::defineMethodsInModule(cu, cf_examples, torch::jit::script::Resolver(), nullptr); - auto run = [&](const std::string & name, std::vector stack) { + auto run = [&](const std::string & name, std::vector stack) { auto graph = cu.get_method(name).graph(); Code code(graph); InterpreterState interp(code); @@ -886,8 +892,8 @@ void testControlFlow() { return stack; }; - auto L = [](int64_t l) { return autograd::make_variable(at::Scalar(l).toTensor()); }; - 
auto V = [](at::Tensor t) { return at::Scalar(t).toLong(); }; + auto L = [](int64_t l) { return IValue(autograd::make_variable(at::Scalar(l).toTensor())); }; + auto V = [](IValue t) { return at::Scalar(std::move(t).toTensor()).toLong(); }; auto run_binary = [&](const std::string & name, int64_t a, int64_t b) { return V(run(name, {L(a), L(b)})[0]); }; @@ -898,6 +904,50 @@ void testControlFlow() { REQUIRE(256 == run_binary("while_test",2,0)); } +void testIValue() { + Shared foo = IntList::create({3, 4, 5}); + JIT_ASSERT(foo->use_count() == 1); + IValue bar(foo); + JIT_ASSERT(foo->use_count() == 2); + auto baz = bar; + JIT_ASSERT(foo->use_count() == 3); + auto foo2 = std::move(bar); + JIT_ASSERT(foo->use_count() == 3); + JIT_ASSERT(foo2.isIntList()); + JIT_ASSERT(bar.isNone()); + foo2 = IValue(4.0); + JIT_ASSERT(foo2.isDouble()); + JIT_ASSERT(foo2.toDouble() == 4.0); + JIT_ASSERT(foo->use_count() == 2); + JIT_ASSERT(baz.toIntList()->elements().equals({3,4,5})); + + auto move_it = std::move(baz).toIntList(); + JIT_ASSERT(foo->use_count() == 2); + JIT_ASSERT(baz.isNone()); + IValue i(4); + JIT_ASSERT(i.isInt() && i.toInt() == 4); + IValue dlist(DoubleList::create({3.5})); + JIT_ASSERT( + dlist.isDoubleList() && + std::move(dlist).toDoubleList()->elements().equals({3.5})); + JIT_ASSERT(dlist.isNone()); + dlist = IValue(DoubleList::create({3.4})); + JIT_ASSERT(dlist.toDoubleList()->elements().equals({3.4})); + IValue the_list(Tuple::create({IValue(3.4), IValue(4), IValue(foo)})); + JIT_ASSERT(foo->use_count() == 3); + JIT_ASSERT(the_list.isTuple()); + auto first = std::move(the_list).toTuple()->elements().at(1); + JIT_ASSERT(first.toInt() == 4); + at::Tensor tv = at::rand({3,4}); + IValue ten(tv); + JIT_ASSERT(tv.get()->use_count() == 2); + auto ten2 = ten; + JIT_ASSERT(tv.get()->use_count() == 3); + JIT_ASSERT(ten2.toTensor().equal(ten.toTensor())); + std::move(ten2).toTensor(); + JIT_ASSERT(tv.get()->use_count() == 2); +} + void testProto() { ::ONNX_NAMESPACE::ModelProto proto; proto.set_producer_name("foo"); @@ -905,6 +955,7 @@ void testProto() { std::string runJITCPPTests() { std::stringstream out; + testIValue(); testControlFlow(); testGraphExecutor(); testBlocks(out); diff --git a/torch/csrc/utils/functional.h b/torch/csrc/utils/functional.h index 3f81228ce8369c..af5099e7ce4e84 100644 --- a/torch/csrc/utils/functional.h +++ b/torch/csrc/utils/functional.h @@ -23,6 +23,15 @@ inline auto fmap(const T& inputs, const F& fn) -> std::vector +inline auto fmap(T& inputs, const F& fn) -> std::vector { + std::vector r; + r.reserve(inputs.size()); + for(auto & input : inputs) + r.push_back(fn(input)); + return r; +} + // C++ forbids taking an address of a constructor, so here's a workaround... 
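// ---------------------------------------------------------------------------
// Aside (illustrative sketch, not part of the patch): the new non-const fmap
// overload above exists so callers can move elements out of the source
// container while mapping, which is how this patch converts a Stack of IValues
// back into tensors, e.g.
//   fmap(stack, [](IValue& v) { return std::move(v).toTensor(); })
// A standalone illustration with a local copy of the helper:
#include <iostream>
#include <string>
#include <utility>
#include <vector>

template <typename T, typename F>
auto fmap(T& inputs, const F& fn) -> std::vector<decltype(fn(*inputs.begin()))> {
  std::vector<decltype(fn(*inputs.begin()))> r;
  r.reserve(inputs.size());
  for (auto& input : inputs)
    r.push_back(fn(input));          // fn may move from its mutable argument
  return r;
}

int main() {
  std::vector<std::string> words{"graph", "executor"};
  // the mutable reference lets the lambda move each element out of `words`
  auto moved = fmap(words, [](std::string& s) { return std::move(s); });
  std::cout << moved[0] << ' ' << moved[1] << '\n';  // prints: graph executor
}
// ---------------------------------------------------------------------------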
// Overload for constructor (R) application template diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index a7d45a7720a417..fbf3fabbcfc113 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -390,6 +390,7 @@ def script_method(fn): def batch(batch_size=1, optimize=True, _frames_up=0): def decorator(fn): + import torch.jit.batchop mod = script(fn, optimize, _frames_up) res_graph = torch.to_batch_graph(mod.graph) res_mod = ScriptModule() diff --git a/torch/lib/THD/base/DataChannelRequest.cpp b/torch/lib/THD/base/DataChannelRequest.cpp index 13540e4337c7cd..a9536b4a44ac3e 100644 --- a/torch/lib/THD/base/DataChannelRequest.cpp +++ b/torch/lib/THD/base/DataChannelRequest.cpp @@ -1,6 +1,6 @@ #include "DataChannelRequest.hpp" -THD_API void THDRequest_free(THDRequest* request) { - delete request; +THD_API void THDRequest_free(void* request) { + delete (THDRequest*)request; } diff --git a/torch/lib/THD/base/DataChannelRequest.h b/torch/lib/THD/base/DataChannelRequest.h index 8112c7188241b4..3a1d70dea22e5e 100644 --- a/torch/lib/THD/base/DataChannelRequest.h +++ b/torch/lib/THD/base/DataChannelRequest.h @@ -7,4 +7,4 @@ struct _THDRequest; typedef struct _THDRequest THDRequest; #endif -THD_API void THDRequest_free(THDRequest* req); +THD_API void THDRequest_free(void* req); diff --git a/torch/lib/THD/master_worker/master/generic/THDStorage.cpp b/torch/lib/THD/master_worker/master/generic/THDStorage.cpp index 353224463e3e97..82c3b1bee6e686 100644 --- a/torch/lib/THD/master_worker/master/generic/THDStorage.cpp +++ b/torch/lib/THD/master_worker/master/generic/THDStorage.cpp @@ -180,8 +180,6 @@ void THDStorage_(free)(THDStorage *storage) { THDState::s_current_worker ); - if (storage->flag & TH_STORAGE_VIEW) - THDStorage_(free)(storage->view); delete storage; } } diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index 56e0f1c05798b4..69518bcb09504e 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -40,14 +40,92 @@ def reduce_event(event): def rebuild_tensor(cls, storage, metadata): storage_offset, size, stride = metadata - new_tensor = cls() - new_tensor.set_(storage, storage_offset, size, stride) - return new_tensor + return torch._utils._rebuild_tensor(storage, storage_offset, size, stride) + + +def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset, + storage_cls, storage_device, storage_handle, storage_size): + + storage = storage_from_cache(storage_cls, storage_handle) + if storage is None: + torch.cuda._lazy_init() + storage = storage_cls._new_shared_cuda(storage_device, storage_handle, storage_size) + shared_cache[storage_handle] = storage._weak_ref(StorageRef) + + return torch._utils._rebuild_tensor(storage, tensor_offset, tensor_size, tensor_stride) def reduce_tensor(tensor): - metadata = (tensor.storage_offset(), tensor.size(), tensor.stride()) storage = tensor.storage() + + # Note [CUDA IPC and the caching allocator] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # When you send a CUDA tensor over IPC, you might expect that you will + # get out the same storage from the other end. However, the CUDA caching + # allocator makes it difficult to preserve this invariant. Consider + # the following situation: a tensor of size 0x100 points to offset 0x20 of + # a storage at 0xA100 of size 0x100. (For simplicity, all of these + # sizes are given in bytes). 
HOWEVER, with the caching allocator, this storage + # might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000. + # + # When we want to send this CUDA tensor over IPC, we must send the + # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just + # the storage 0xA100 (because that is what CUDA supports). So, on the + # other end, there simply isn't any way to say, "Wait, you gave me + # a bigger region (0xA000) than the one I wanted (0xA100)"; we have + # to just make a storage for the entire caching allocator block. + # + # This is fine, because all we need to do is just adjust the offset + # on the tensor itself: instead of: + # + # Tensor(size=0x100, offset=0x020, storage=Storage(data=0xA100, size=0x0100)) + # + # we have + # + # Tensor(size=0x100, offset=0x120, storage=Storage(data=0xA000, size=0x4000)) + # + # This strategy has a few implications: + # + # 1. When we serialize a CUDA tensor for IPC, we have to do it all in one + # go (non-compositionally), instead of first serializing storage, and + # then serializing tensor. This is because the base address of the + # storage allocation affects what offset we write into the tensor. + # + # 2. We MUST NOT let the new IPC tensor be resizable. Originally, a resize + # of the storage beyond 0x100 would merely have caused us to do a + # reallocation. You don't really want to do this, but if you did, + # all that would happen is that you would lose IPC sharing. But if + # you do this in the new world, we will happily let you write out of + # bounds of your "allocation", clobbering unrelated data in the cached + # allocator block. BAD! + # + # By the way, in old versions of PyTorch, we supported this situation + # natively using a "storage view", which permitted multiple storages to be + # views on each other. But this was the *only* use of storage views, so we + # eliminated it so that we could just use tensor views to implement the same + # thing. + # + if storage.is_cuda: + (device, handle, storage_size, storage_offset) = storage._share_cuda_() + tensor_offset = tensor.storage_offset() + + # WARNING! This call to _weak_ref could lead to O(n) deleter + # behavior, if you repeatedly call it on the same Storage (all + # other sites are guarded by shared_cache; maybe this site + # should be too?) + shared_cache[handle] = storage._weak_ref(StorageRef) + + return (rebuild_cuda_tensor, + (type(tensor), + tensor.size(), + tensor.stride(), + tensor_offset + storage_offset, + type(storage), + device, + handle, + storage_size)) + + metadata = (tensor.storage_offset(), tensor.size(), tensor.stride()) return (rebuild_tensor, (type(tensor), storage, metadata)) @@ -91,16 +169,6 @@ def rebuild_storage_filename(cls, manager, handle, size): return storage._shared_decref() -def rebuild_storage_cuda(cls, device, handle, size, offset, view_size): - storage = storage_from_cache(cls, handle) - if storage is not None: - return storage._new_view(offset, view_size) - torch.cuda._lazy_init() - storage = cls._new_shared_cuda(device, handle, size, offset, view_size) - shared_cache[handle] = storage._weak_ref(StorageRef) - return storage - - def rebuild_storage_empty(cls): return cls() @@ -108,9 +176,7 @@ def rebuild_storage_empty(cls): def reduce_storage(storage): from . 
import get_sharing_strategy if storage.is_cuda: - metadata = storage._share_cuda_() - cache_key = metadata[1] - rebuild = rebuild_storage_cuda + raise RuntimeError("Cannot pickle CUDA storage; try pickling a CUDA tensor instead") elif get_sharing_strategy() == 'file_system': metadata = storage._share_filename_() cache_key = metadata[1] @@ -146,3 +212,6 @@ def init_reductions(): for t in torch._tensor_classes: ForkingPickler.register(t, reduce_tensor) + + # TODO: Maybe this should be in tensor_classes? :) + ForkingPickler.register(torch.Tensor, reduce_tensor) diff --git a/torch/serialization.py b/torch/serialization.py index 67d5e2d1b4f08b..85c1a83429f26b 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -243,11 +243,13 @@ def persistent_id(obj): return ('module', obj, source_file, source) elif torch.is_storage(obj): storage_type = normalize_storage_type(type(obj)) - root, offset = obj._root_storage() - root_key = str(root._cdata) + # Offset is always 0, but we keep it for backwards compatibility + # with the old serialization format (which supported storage views) + offset = 0 + obj_key = str(obj._cdata) location = location_tag(obj) - serialized_storages[root_key] = root - is_view = obj._cdata != root._cdata + serialized_storages[obj_key] = obj + is_view = obj._cdata != obj._cdata if is_view: view_metadata = (str(obj._cdata), offset, obj.size()) else: @@ -255,9 +257,9 @@ def persistent_id(obj): return ('storage', storage_type, - root_key, + obj_key, location, - root.size(), + obj.size(), view_metadata) return None @@ -449,7 +451,20 @@ def persistent_load(saved_id): storage_views = pickle_module.load(f) for target_cdata, root_cdata, offset, size in storage_views: root = deserialized_objects[root_cdata] - deserialized_objects[target_cdata] = root[offset:offset + size] + if offset != 0 or size != root.size(): + warnings.warn("Detected storage view in legacy serialized data: " + "storage views are no longer natively supported, so we are making " + "a copy of the data instead. THIS IS A SEMANTIC CHANGE! " + "If you need aliasing, reserialize your model using " + "tensors that share storage.") + + tensor = torch._utils._rebuild_tensor(root, offset, (size,), (1,)) + obj = tensor.clone().storage() + else: + # NB: This line does not appear to be exercised by the + # test suite. + obj = root + deserialized_objects[target_cdata] = obj tar.extract('tensors', path=tmpdir) with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
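Aside (illustrative sketch, not part of the patch): the offset bookkeeping described in the "CUDA IPC and the caching allocator" note earlier in this diff reduces to simple arithmetic. The sender shares the whole base cudaMalloc block and reports the storage's element offset inside it; the receiver rebuilds the tensor on the base storage with its original storage_offset() shifted by that amount. The numbers below are the hypothetical ones from the note (base block 0xA000, storage 0xA100, tensor 0x20 bytes into the storage); element_size is an assumed sizeof(real) of 4.

// Standalone sketch of the sender/receiver offset arithmetic.
#include <cstdint>
#include <cstdio>

int main() {
  const std::uintptr_t base_ptr    = 0xA000;  // base allocation shared via cudaIpcGetMemHandle
  const std::uintptr_t storage_ptr = 0xA100;  // original storage's data pointer
  const std::size_t element_size   = 4;       // assumed sizeof(real), e.g. float

  // sender (THPStorage_(shareCuda)): storage offset within the base block, in elements
  const std::size_t storage_offset = (storage_ptr - base_ptr) / element_size;   // 64
  // receiver (reduce_tensor / rebuild_cuda_tensor): shift the tensor's own offset
  const std::size_t tensor_offset  = 0x20 / element_size;                       // 8, i.e. tensor.storage_offset()
  const std::size_t rebuilt_offset = tensor_offset + storage_offset;            // 72 == 0x120 / 4

  std::printf("rebuilt tensor offset = %zu elements\n", rebuilt_offset);
}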