From 13415899c5b9dad08956950471631dfdf8960dc6 Mon Sep 17 00:00:00 2001
From: Kevin Chen
Date: Thu, 29 Aug 2019 14:09:54 -0700
Subject: [PATCH] Add Int8Transpose operator (#16382)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16382

Adding an Int8TransposeOp that inherits from TransposeOp. Small refactoring
to normal TransposeOp to move main logic into a TransposeImpl function.

Test Plan: int8_test.cc

Differential Revision: D13822715

fbshipit-source-id: 5aac04b0f0a4cc471714221ecb8a0e0ae318834f
---
 caffe2/operators/quantized/CMakeLists.txt    |  3 +-
 caffe2/operators/quantized/int8_test.cc      | 19 ++++++++++
 .../operators/quantized/int8_transpose_op.cc | 28 ++++++++++++++
 .../operators/quantized/int8_transpose_op.h  | 38 +++++++++++++++++++
 caffe2/operators/transpose_op.h              | 16 +++++---
 5 files changed, 97 insertions(+), 7 deletions(-)
 create mode 100644 caffe2/operators/quantized/int8_transpose_op.cc
 create mode 100644 caffe2/operators/quantized/int8_transpose_op.h

diff --git a/caffe2/operators/quantized/CMakeLists.txt b/caffe2/operators/quantized/CMakeLists.txt
index 3c8f5aa00a350..d5d1a4499f2ae 100644
--- a/caffe2/operators/quantized/CMakeLists.txt
+++ b/caffe2/operators/quantized/CMakeLists.txt
@@ -21,7 +21,8 @@ list(APPEND Caffe2_CPU_SRCS
     "${CMAKE_CURRENT_SOURCE_DIR}/int8_roi_align_op.cc"
     "${CMAKE_CURRENT_SOURCE_DIR}/int8_slice_op.cc"
     "${CMAKE_CURRENT_SOURCE_DIR}/int8_sigmoid_op.cc"
-    "${CMAKE_CURRENT_SOURCE_DIR}/int8_softmax_op.cc")
+    "${CMAKE_CURRENT_SOURCE_DIR}/int8_softmax_op.cc"
+    "${CMAKE_CURRENT_SOURCE_DIR}/int8_transpose_op.cc")
 
 # ---[ CPU test files
 list(APPEND Caffe2_CPU_TEST_SRCS
diff --git a/caffe2/operators/quantized/int8_test.cc b/caffe2/operators/quantized/int8_test.cc
index af7a5accf9094..e6f4369628f4e 100644
--- a/caffe2/operators/quantized/int8_test.cc
+++ b/caffe2/operators/quantized/int8_test.cc
@@ -1002,4 +1002,23 @@ TEST(Int8, Slice) {
   EXPECT_EQ(YQ.scale, XQ->scale);
   EXPECT_EQ(YQ.zero_point, XQ->zero_point);
 }
+
+TEST(Int8, Transpose) {
+  auto XQ = q({1, 50, 25, 16});
+  auto xop = CreateOperatorDef(
+      "Int8Transpose",
+      "",
+      {"XQ"},
+      {"YQ"},
+      {MakeArgument("axes", vector<int>{0, 3, 1, 2}),
+       MakeArgument("Y_scale", XQ->scale),
+       MakeArgument("Y_zero_point", XQ->zero_point)});
+  Workspace ws;
+  int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+  ws.RunOperatorOnce(xop);
+  const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+  EXPECT_EQ(YQ.t.sizes(), (vector<int64_t>{1, 16, 50, 25}));
+  EXPECT_EQ(YQ.scale, XQ->scale);
+  EXPECT_EQ(YQ.zero_point, XQ->zero_point);
+}
 } // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_transpose_op.cc b/caffe2/operators/quantized/int8_transpose_op.cc
new file mode 100644
index 0000000000000..e7d5d5133a5bb
--- /dev/null
+++ b/caffe2/operators/quantized/int8_transpose_op.cc
@@ -0,0 +1,28 @@
+#include "caffe2/operators/quantized/int8_transpose_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8Transpose, int8::Int8TransposeOp);
+
+OPERATOR_SCHEMA(Int8Transpose)
+    .NumInputs(1)
+    .NumOutputs(1)
+    .SetDoc(R"DOC(
+Transpose the input tensor by permuting the axes of the input according
+to the `axes` argument. Similar to numpy's
+[transpose](https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html)
+function.
+
+For example, when axes=(1, 0, 2), given an input tensor of shape
+(1, 2, 3), the output shape will be (2, 1, 3).
+)DOC")
+    .Arg(
+        "axes",
+        "*(type: Tuple(int))* Order to permute axes of input tensor. Reverses "
+        "the dimensions by default.")
+    .Arg("Y_scale", "Output tensor quantization scale")
+    .Arg("Y_zero_point", "Output tensor quantization offset")
+    .Input(0, "X", "Input tensor")
+    .Output(0, "Y", "Transposed output");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_transpose_op.h b/caffe2/operators/quantized/int8_transpose_op.h
new file mode 100644
index 0000000000000..7409ecb448a19
--- /dev/null
+++ b/caffe2/operators/quantized/int8_transpose_op.h
@@ -0,0 +1,38 @@
+#ifndef CAFFE2_OPERATORS_INT8_TRANSPOSE_OP_H_
+#define CAFFE2_OPERATORS_INT8_TRANSPOSE_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+#include "caffe2/operators/transpose_op.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+class Int8TransposeOp final : public TransposeOp<CPUContext> {
+ public:
+  Int8TransposeOp(const OperatorDef& operator_def, Workspace* ws)
+      : TransposeOp(operator_def, ws) {}
+
+  bool RunOnDevice() override {
+    auto& X = Inputs()[0]->Get<Int8TensorCPU>();
+    auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
+    int32_t Y_zero_point =
+        this->template GetSingleArgument<int32_t>("Y_zero_point", 0);
+    auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+    CAFFE_ENFORCE_EQ(Y_zero_point, X.zero_point);
+    CAFFE_ENFORCE_EQ(Y_scale, X.scale);
+    Y->scale = Y_scale;
+    Y->zero_point = Y_zero_point;
+    TransposeImpl<uint8_t>(X.t, &Y->t);
+    return true;
+  }
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_TRANSPOSE_OP_H_
diff --git a/caffe2/operators/transpose_op.h b/caffe2/operators/transpose_op.h
index f84427dbdf054..1bda5fb8443f2 100644
--- a/caffe2/operators/transpose_op.h
+++ b/caffe2/operators/transpose_op.h
@@ -11,7 +11,7 @@
 namespace caffe2 {
 
 template <class Context>
-class TransposeOp final : public Operator<Context> {
+class TransposeOp : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
   USE_DISPATCH_HELPER;
@@ -36,11 +36,9 @@ class TransposeOp final : public Operator<Context> {
         this, Input(0));
   }
 
- private:
+ protected:
   template <typename T>
-  bool DoRunWithType() {
-    const auto& X = Input(0);
-
+  void TransposeImpl(const Tensor& X, Tensor* Y) {
     const int ndim = X.dim();
     if (axes_.empty()) {
       axes_.resize(ndim);
@@ -53,7 +51,7 @@
     for (int i = 0; i < ndim; ++i) {
       Y_dims[i] = X_dims[axes_[i]];
     }
-    auto* Y = Output(0, Y_dims, at::dtype<T>());
+    Y->Resize(Y_dims);
     math::Transpose(
         X_dims.size(),
         X_dims.data(),
@@ -61,6 +59,12 @@ class TransposeOp final : public Operator<Context> {
         X.template data<T>(),
         Y->template mutable_data<T>(),
         &context_);
+  }
+
+ private:
+  template <typename T>
+  bool DoRunWithType() {
+    TransposeImpl<T>(Input(0), Output(0));
     return true;
   }
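
Note (illustration, not part of the patch): the shape rule that TransposeImpl and the Int8Transpose test above rely on is simply "output dim i takes input dim axes[i]". A minimal standalone C++ sketch of that rule, with PermuteDims as a hypothetical helper name:

// Mirrors Y_dims[i] = X_dims[axes_[i]] from TransposeImpl.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> PermuteDims(
    const std::vector<int64_t>& dims,
    const std::vector<int>& axes) {
  std::vector<int64_t> out(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) {
    out[i] = dims[axes[i]];  // output dim i takes input dim axes[i]
  }
  return out;
}

int main() {
  // Schema example: axes=(1, 0, 2) maps shape (1, 2, 3) to (2, 1, 3).
  for (auto d : PermuteDims({1, 2, 3}, {1, 0, 2})) std::cout << d << ' ';
  std::cout << '\n';
  // Test case above: axes=(0, 3, 1, 2) maps NHWC (1, 50, 25, 16) to NCHW (1, 16, 50, 25).
  for (auto d : PermuteDims({1, 50, 25, 16}, {0, 3, 1, 2})) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}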