Skip to content

Commit 01fd409

Browse files
committed
Fuse computations into the Tensor contractions using output kernel
1 parent 5539587 commit 01fd409

File tree

6 files changed

+248
-37
lines changed

6 files changed

+248
-37
lines changed

unsupported/Eigen/CXX11/src/Tensor/TensorBase.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -517,9 +517,15 @@ class TensorBase<Derived, ReadOnlyAccessors>
517517
typedef Eigen::IndexPair<Index> DimensionPair;
518518

519519
template<typename OtherDerived, typename Dimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
520-
const TensorContractionOp<const Dimensions, const Derived, const OtherDerived>
520+
const TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const NoOpOutputKernel>
521521
contract(const OtherDerived& other, const Dimensions& dims) const {
522-
return TensorContractionOp<const Dimensions, const Derived, const OtherDerived>(derived(), other.derived(), dims);
522+
return TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const NoOpOutputKernel>(derived(), other.derived(), dims);
523+
}
524+
525+
template<typename OtherDerived, typename Dimensions, typename OutputKernel> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
526+
const TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const OutputKernel>
527+
contract(const OtherDerived& other, const Dimensions& dims, const OutputKernel& output_kernel) const {
528+
return TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const OutputKernel>(derived(), other.derived(), dims, output_kernel);
523529
}
524530

525531
// Convolutions.

unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h

Lines changed: 92 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ template<typename LhsScalar, typename RhsScalar, typename Scalar>
8585
#endif
8686

8787

88-
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
89-
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
88+
template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
89+
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> >
9090
{
9191
// Type promotion to handle the case where the types of the lhs and the rhs are different.
9292
typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type,
@@ -112,23 +112,24 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
112112
};
113113
};
114114

115-
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
116-
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>
115+
template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
116+
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, Eigen::Dense>
117117
{
118-
typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type;
118+
typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>& type;
119119
};
120120

121-
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
122-
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>
121+
template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
122+
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> >::type>
123123
{
124-
typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type;
124+
typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> type;
125125
};
126126

127-
template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_>
128-
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > {
127+
template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename OutputKernelType_, typename Device_>
128+
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_, OutputKernelType_>, Device_> > {
129129
typedef Indices_ Indices;
130130
typedef LeftArgType_ LeftArgType;
131131
typedef RightArgType_ RightArgType;
132+
typedef OutputKernelType_ OutputKernelType;
132133
typedef Device_ Device;
133134

134135
// From NumDims below.
@@ -137,8 +138,52 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
137138

138139
} // end namespace internal
139140

140-
template<typename Indices, typename LhsXprType, typename RhsXprType>
141-
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors>
141+
// Tensor contraction parameters that allow an output kernel to map
142+
// 2-dimensional output matrix coordinates to the output tensor dimensions.
143+
struct TensorContractionParams {
144+
// The TensorContraction evaluator assumes that both tensors are in ColMajor
145+
// layout; if the tensors are in RowMajor, the evaluator swaps lhs with rhs.
146+
bool swapped_arguments;
147+
};
148+
149+
// An output kernel makes it possible to fuse operations into the tensor contraction.
150+
//
151+
// Examples:
152+
// 1. Elementwise Relu transformation following Conv2D.
153+
// 2. AddBias to the Conv2D output channels dimension.
154+
//
155+
// See NoOpOutputKernel for the expected call operator signature.
156+
struct OutputKernel {
157+
// Mapper type through which an output kernel accesses the output matrix:
// a column-major blas_data_mapper over Scalar, indexed by Index.
template <typename Index, typename Scalar>
158+
using OutputMapper = internal::blas_data_mapper<Scalar, Index, ColMajor>;
159+
};
160+
161+
// Output kernel that does nothing (used by contract() when no kernel is given).
162+
struct NoOpOutputKernel {
163+
/**
164+
* The tensor contraction evaluator calls this kernel after finishing each
165+
* block of the output matrix. Output blocks belong to the 2-dimensional
166+
* output tensor.
167+
* TensorContractionParams contains the contraction dimension information
168+
* required to map the 2-d output space into the expected output tensor space
169+
* (potentially higher dimensional).
170+
*
171+
* \param[in] output_mapper Access to output tensor memory
172+
* \param[in] params Tensor contraction parameters
173+
* \param[in] i Index of the first row available through output_mapper
174+
* \param[in] j Index of the first column available through output_mapper
175+
* \param[in] num_rows Number of available rows
176+
* \param[in] num_cols Number of available columns
177+
*/
178+
template <typename Index, typename Scalar>
179+
EIGEN_ALWAYS_INLINE void operator()(
180+
const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
181+
const TensorContractionParams& params, Index i, Index j, Index num_rows,
182+
Index num_cols) const {}
183+
};
184+
185+
template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
186+
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors>
142187
{
143188
public:
144189
typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
@@ -149,8 +194,10 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp
149194
typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;
150195

151196
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
152-
const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
153-
: m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}
197+
const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims,
198+
const OutputKernelType& output_kernel = OutputKernelType())
199+
: m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims),
200+
m_output_kernel(output_kernel) {}
154201

155202
EIGEN_DEVICE_FUNC
156203
const Indices& indices() const { return m_indices; }
@@ -164,10 +211,14 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp
164211
const typename internal::remove_all<typename RhsXprType::Nested>::type&
165212
rhsExpression() const { return m_rhs_xpr; }
166213

214+
EIGEN_DEVICE_FUNC
215+
const OutputKernelType& outputKernel() const { return m_output_kernel; }
216+
167217
protected:
168218
typename LhsXprType::Nested m_lhs_xpr;
169219
typename RhsXprType::Nested m_rhs_xpr;
170220
const Indices m_indices;
221+
const OutputKernelType m_output_kernel;
171222
};
172223

173224

@@ -177,9 +228,10 @@ struct TensorContractionEvaluatorBase
177228
typedef typename internal::traits<Derived>::Indices Indices;
178229
typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
179230
typedef typename internal::traits<Derived>::RightArgType RightArgType;
231+
typedef typename internal::traits<Derived>::OutputKernelType OutputKernelType;
180232
typedef typename internal::traits<Derived>::Device Device;
181233

182-
typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
234+
typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
183235
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
184236
typedef typename XprType::Index Index;
185237
typedef typename XprType::CoeffReturnType CoeffReturnType;
@@ -221,6 +273,7 @@ struct TensorContractionEvaluatorBase
221273
op.lhsExpression(), op.rhsExpression()), device),
222274
m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
223275
op.rhsExpression(), op.lhsExpression()), device),
276+
m_output_kernel(op.outputKernel()),
224277
m_device(device),
225278
m_result(NULL) {
226279
EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
@@ -391,6 +444,13 @@ struct TensorContractionEvaluatorBase
391444
numext::swap(m_dimensions[i], m_dimensions[j]);
392445
}
393446
}
447+
448+
// A set of parameters that allows the output kernel to get from the output
449+
// matrix coordinates (i, j) to the original tensor dimensions.
450+
// TODO(ezhulenev): Add parameters required to infer output tensor index for
451+
// more complex contractions than 2x2 on internal dimension.
452+
m_tensor_contraction_params = {
453+
/**swapped_arguments=*/static_cast<int>(Layout) == RowMajor};
394454
}
395455

396456
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
@@ -585,7 +645,15 @@ struct TensorContractionEvaluatorBase
585645

586646
// call gebp (matrix kernel)
587647
// The parameters here are copied from Eigen's GEMM implementation
588-
gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0);
648+
const auto output_mapper = output.getSubMapper(i2, j2);
649+
gebp(output_mapper, blockA, blockB, actual_mc, actual_kc, actual_nc,
650+
Scalar(1), -1, -1, 0, 0);
651+
652+
// We are done with this [i2, j2] output block.
653+
if (k2 + kc >= k) {
654+
m_output_kernel(output_mapper, m_tensor_contraction_params, i2, j2,
655+
actual_mc, actual_nc);
656+
}
589657
}
590658
}
591659
}
@@ -848,23 +916,26 @@ struct TensorContractionEvaluatorBase
848916
Index m_j_size;
849917
Index m_k_size;
850918

919+
TensorContractionParams m_tensor_contraction_params;
920+
851921
TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
852922
TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
853923
const Device& m_device;
924+
OutputKernelType m_output_kernel;
854925
Scalar* m_result;
855926
bool m_can_use_xsmm;
856927
};
857928

858929

859930
// evaluator for default device
860-
template<typename Indices, typename LeftArgType, typename RightArgType, typename Device>
861-
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> :
931+
template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType, typename Device>
932+
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> :
862933
public TensorContractionEvaluatorBase<
863-
TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
864-
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
934+
TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> > {
935+
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
865936
typedef TensorContractionEvaluatorBase<Self> Base;
866937

867-
typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
938+
typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
868939
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
869940
typedef typename XprType::Index Index;
870941
typedef typename XprType::CoeffReturnType CoeffReturnType;

unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,16 @@ struct packRhsAndKernelArg {
5656
} // end namespace internal
5757
#endif // EIGEN_USE_SIMPLE_THREAD_POOL
5858

59-
template<typename Indices, typename LeftArgType, typename RightArgType>
60-
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> :
61-
public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> > {
59+
template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
60+
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> :
61+
public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> > {
6262

6363
typedef ThreadPoolDevice Device;
6464

65-
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
65+
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
6666
typedef TensorContractionEvaluatorBase<Self> Base;
6767

68-
typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
68+
typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
6969
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
7070
typedef typename XprType::Index Index;
7171
typedef typename XprType::CoeffReturnType CoeffReturnType;
@@ -308,7 +308,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
308308
this->m_k_strides);
309309

310310
Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper,
311-
OutputMapper>(this->m_device, num_threads, lhs, rhs, buffer, m, n,
311+
OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n,
312312
k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0,
313313
shard_by_col, parallel_pack)
314314
.run();
@@ -319,16 +319,18 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
319319
typename LhsMapper, typename RhsMapper, typename OutputMapper>
320320
class Context {
321321
public:
322-
Context(const Device& device, int num_threads, LhsMapper& lhs,
322+
Context(const Self* self, int num_threads, LhsMapper& lhs,
323323
RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
324324
Index bn, Index bk, Index nm, Index nn, Index nk, Index gm,
325325
Index gn, Index nm0, Index nn0, bool shard_by_col,
326326
bool parallel_pack)
327-
: device_(device),
327+
: device_(self->m_device),
328328
lhs_(lhs),
329329
rhs_(rhs),
330330
buffer_(buffer),
331331
output_(buffer, tm),
332+
output_kernel_(self->m_output_kernel),
333+
tensor_contraction_params_(self->m_tensor_contraction_params),
332334
num_threads_(num_threads),
333335
shard_by_col_(shard_by_col),
334336
parallel_pack_(parallel_pack),
@@ -420,6 +422,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
420422
RhsMapper& rhs_;
421423
Scalar* const buffer_;
422424
OutputMapper output_;
425+
OutputKernelType output_kernel_;
426+
TensorContractionParams tensor_contraction_params_;
423427
const int num_threads_;
424428
const bool shard_by_col_;
425429
const bool parallel_pack_;
@@ -536,19 +540,32 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
536540
const Index mend = m * gm_ + gm(m);
537541
if (shard_by_col_) {
538542
for (Index n1 = n * gn_; n1 < nend; n1++) {
539-
for (Index m1 = m * gm_; m1 < mend; m1++)
540-
GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
541-
packed_lhs_[k % (P - 1)][m1],
543+
for (Index m1 = m * gm_; m1 < mend; m1++) {
544+
const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
545+
GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1],
542546
packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
543547
Scalar(1), -1, -1, 0, 0);
548+
549+
// We are done with the last task for the [m1, n1] block.
550+
if (k + 1 == nk_) {
551+
output_kernel_(output_mapper, tensor_contraction_params_,
552+
m1 * bm_, n1 * bn_, bm(m1), bn(n1));
553+
}
554+
}
544555
}
545556
} else {
546557
for (Index m1 = m * gm_; m1 < mend; m1++)
547558
for (Index n1 = n * gn_; n1 < nend; n1++) {
548-
GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
549-
packed_lhs_[k % (P - 1)][m1],
559+
const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
560+
GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1],
550561
packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
551562
Scalar(1), -1, -1, 0, 0);
563+
564+
// We are done with the last task for the [m1, n1] block.
565+
if (k + 1 == nk_) {
566+
output_kernel_(output_mapper, tensor_contraction_params_,
567+
m1 * bm_, n1 * bn_, bm(m1), bn(n1));
568+
}
552569
}
553570
}
554571
signal_kernel(m, n, k + 1, false);
@@ -747,6 +764,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
747764
}
748765

749766
#else // EIGEN_USE_SIMPLE_THREAD_POOL
767+
// TODO(ezhulenev): SimpleThreadPool will be removed in the future, so it does
768+
// not seem worth adding output kernel support here.
769+
static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
770+
"SimpleThreadPool does not support contraction output kernels.");
750771

751772
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
752773
void evalProduct(Scalar* buffer) const {
@@ -1065,6 +1086,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
10651086
}
10661087

10671088
#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
1089+
// TODO(ezhulenev): Add support for output kernels and LIBXSMM.
1090+
static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
1091+
"XSMM does not support contraction output kernels.");
1092+
10681093
template<int Alignment>
10691094
class ContextXsmm {
10701095
public:

unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ template<typename Op, typename Dims, typename XprType, template <class> class Ma
6565
template<typename XprType> class TensorIndexTupleOp;
6666
template<typename ReduceOp, typename Dims, typename XprType> class TensorTupleReducerOp;
6767
template<typename Axis, typename LeftXprType, typename RightXprType> class TensorConcatenationOp;
68-
template<typename Dimensions, typename LeftXprType, typename RightXprType> class TensorContractionOp;
68+
template<typename Dimensions, typename LeftXprType, typename RightXprType, typename OutputKernelType> class TensorContractionOp;
6969
template<typename TargetType, typename XprType> class TensorConversionOp;
7070
template<typename Dimensions, typename InputXprType, typename KernelXprType> class TensorConvolutionOp;
7171
template<typename FFT, typename XprType, int FFTDataType, int FFTDirection> class TensorFFTOp;
@@ -97,6 +97,8 @@ template<typename XprType> class TensorForcedEvalOp;
9797
template<typename ExpressionType, typename DeviceType> class TensorDevice;
9898
template<typename Derived, typename Device> struct TensorEvaluator;
9999

100+
struct NoOpOutputKernel;  // Defined as a struct in TensorContraction.h; the class-key must match.
101+
100102
struct DefaultDevice;
101103
struct ThreadPoolDevice;
102104
struct GpuDevice;

0 commit comments

Comments
 (0)