@@ -85,8 +85,8 @@ template<typename LhsScalar, typename RhsScalar, typename Scalar>
85
85
#endif
86
86
87
87
88
- template <typename Dimensions, typename LhsXprType, typename RhsXprType>
89
- struct traits <TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
88
+ template <typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType >
89
+ struct traits <TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType > >
90
90
{
91
91
// Type promotion to handle the case where the types of the lhs and the rhs are different.
92
92
typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type,
@@ -112,23 +112,24 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
112
112
};
113
113
};
114
114
115
- template <typename Dimensions, typename LhsXprType, typename RhsXprType>
116
- struct eval <TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>
115
+ template <typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType >
116
+ struct eval <TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType >, Eigen::Dense>
117
117
{
118
- typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type;
118
+ typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType >& type;
119
119
};
120
120
121
- template <typename Dimensions, typename LhsXprType, typename RhsXprType>
122
- struct nested <TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1 , typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>
121
+ template <typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType >
122
+ struct nested <TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType >, 1 , typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType > >::type>
123
123
{
124
- typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type;
124
+ typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType > type;
125
125
};
126
126
127
- template <typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_>
128
- struct traits <TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > {
127
+ template <typename Indices_, typename LeftArgType_, typename RightArgType_, typename OutputKernelType_, typename Device_>
128
+ struct traits <TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_, OutputKernelType_ >, Device_> > {
129
129
typedef Indices_ Indices;
130
130
typedef LeftArgType_ LeftArgType;
131
131
typedef RightArgType_ RightArgType;
132
+ typedef OutputKernelType_ OutputKernelType;
132
133
typedef Device_ Device;
133
134
134
135
// From NumDims below.
@@ -137,8 +138,52 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
137
138
138
139
} // end namespace internal
139
140
140
- template <typename Indices, typename LhsXprType, typename RhsXprType>
141
- class TensorContractionOp : public TensorBase <TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors>
141
+ // Tensor contraction params that should enable to get from output matrix
142
+ // 2-dimensional coordinates to the output tensor dimensions.
143
+ struct TensorContractionParams {
144
+ // TensorContraction evaluator assumes that both tensors are in ColMajor
145
+ // layout, if tensors are in RowMajor evaluator swap lhs with rhs.
146
+ bool swapped_arguments;
147
+ };
148
+
149
+ // Output kernel allows to fuse operations into the tensor contraction.
150
+ //
151
+ // Examples:
152
+ // 1. Elementwise Relu transformation following Conv2D.
153
+ // 2. AddBias to the Conv2D output channels dimension.
154
+ //
155
+ // See expected implementation in NoOpOutputKernel.
156
+ struct OutputKernel {
157
+ template <typename Index, typename Scalar>
158
+ using OutputMapper = internal::blas_data_mapper<Scalar, Index, ColMajor>;
159
+ };
160
+
161
+ // Output kernel that does absolutely nothing.
162
+ struct NoOpOutputKernel {
163
+ /* *
164
+ * Tensor contraction evaluator calls this kernel after finishing each block
165
+ * of output matrix. Output blocks belong to the 2-dimensional output tensor.
166
+ *
167
+ * TensorContractionParams contains contraction dimensions information
168
+ * required to map output 2-d space into the expected output tensor space
169
+ * (potentially higher dimensional).
170
+ *
171
+ * \param[in] output_mapper Access to output tensor memory
172
+ * \param[in] params Tensor contraction parameters
173
+ * \param[in] i Index of a first row available through output_mapper
174
+ * \param[in] j Index of a first column available through output_mapper
175
+ * \param[in] num_rows Number of available rows
176
+ * \param[in] num_cols Number of available columns
177
+ */
178
+ template <typename Index, typename Scalar>
179
+ EIGEN_ALWAYS_INLINE void operator ()(
180
+ const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
181
+ const TensorContractionParams& params, Index i, Index j, Index num_rows,
182
+ Index num_cols) const {}
183
+ };
184
+
185
+ template <typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
186
+ class TensorContractionOp : public TensorBase <TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors>
142
187
{
143
188
public:
144
189
typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
@@ -149,8 +194,10 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp
149
194
typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;
150
195
151
196
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp (
152
- const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
153
- : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}
197
+ const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims,
198
+ const OutputKernelType& output_kernel = OutputKernelType())
199
+ : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims),
200
+ m_output_kernel(output_kernel) {}
154
201
155
202
EIGEN_DEVICE_FUNC
156
203
const Indices& indices () const { return m_indices; }
@@ -164,10 +211,14 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp
164
211
const typename internal::remove_all<typename RhsXprType::Nested>::type&
165
212
rhsExpression () const { return m_rhs_xpr; }
166
213
214
+ EIGEN_DEVICE_FUNC
215
+ const OutputKernelType& outputKernel () const { return m_output_kernel; }
216
+
167
217
protected:
168
218
typename LhsXprType::Nested m_lhs_xpr;
169
219
typename RhsXprType::Nested m_rhs_xpr;
170
220
const Indices m_indices;
221
+ const OutputKernelType m_output_kernel;
171
222
};
172
223
173
224
@@ -177,9 +228,10 @@ struct TensorContractionEvaluatorBase
177
228
typedef typename internal::traits<Derived>::Indices Indices;
178
229
typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
179
230
typedef typename internal::traits<Derived>::RightArgType RightArgType;
231
+ typedef typename internal::traits<Derived>::OutputKernelType OutputKernelType;
180
232
typedef typename internal::traits<Derived>::Device Device;
181
233
182
- typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
234
+ typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType > XprType;
183
235
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
184
236
typedef typename XprType::Index Index;
185
237
typedef typename XprType::CoeffReturnType CoeffReturnType;
@@ -221,6 +273,7 @@ struct TensorContractionEvaluatorBase
221
273
op.lhsExpression(), op.rhsExpression()), device),
222
274
m_rightImpl (choose(Cond<static_cast <int >(Layout) == static_cast<int>(ColMajor)>(),
223
275
op.rhsExpression(), op.lhsExpression()), device),
276
+ m_output_kernel(op.outputKernel()),
224
277
m_device(device),
225
278
m_result(NULL ) {
226
279
EIGEN_STATIC_ASSERT ((static_cast <int >(TensorEvaluator<LeftArgType, Device>::Layout) ==
@@ -391,6 +444,13 @@ struct TensorContractionEvaluatorBase
391
444
numext::swap (m_dimensions[i], m_dimensions[j]);
392
445
}
393
446
}
447
+
448
+ // A set of parameters that will allow output kernel to get from output
449
+ // tensor dimensions (i, j) into the original tensor dimensions.
450
+ // TODO(ezhulenev): Add parameters required to infer output tensor index for
451
+ // more complex contractions than 2x2 on internal dimension.
452
+ m_tensor_contraction_params = {
453
+ /* *swapped_arguments=*/ static_cast <int >(Layout) == RowMajor};
394
454
}
395
455
396
456
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions () const { return m_dimensions; }
@@ -585,7 +645,15 @@ struct TensorContractionEvaluatorBase
585
645
586
646
// call gebp (matrix kernel)
587
647
// The parameters here are copied from Eigen's GEMM implementation
588
- gebp (output.getSubMapper (i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar (1 ), -1 , -1 , 0 , 0 );
648
+ const auto output_mapper = output.getSubMapper (i2, j2);
649
+ gebp (output_mapper, blockA, blockB, actual_mc, actual_kc, actual_nc,
650
+ Scalar (1 ), -1 , -1 , 0 , 0 );
651
+
652
+ // We are done with this [i2, j2] output block.
653
+ if (k2 + kc >= k) {
654
+ m_output_kernel (output_mapper, m_tensor_contraction_params, i2, j2,
655
+ actual_mc, actual_nc);
656
+ }
589
657
}
590
658
}
591
659
}
@@ -848,23 +916,26 @@ struct TensorContractionEvaluatorBase
848
916
Index m_j_size;
849
917
Index m_k_size;
850
918
919
+ TensorContractionParams m_tensor_contraction_params;
920
+
851
921
TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
852
922
TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
853
923
const Device& m_device;
924
+ OutputKernelType m_output_kernel;
854
925
Scalar* m_result;
855
926
bool m_can_use_xsmm;
856
927
};
857
928
858
929
859
930
// evaluator for default device
860
- template <typename Indices, typename LeftArgType, typename RightArgType, typename Device>
861
- struct TensorEvaluator <const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> :
931
+ template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType, typename Device>
932
+ struct TensorEvaluator <const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType >, Device> :
862
933
public TensorContractionEvaluatorBase<
863
- TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
864
- typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
934
+ TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType >, Device> > {
935
+ typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType >, Device> Self;
865
936
typedef TensorContractionEvaluatorBase<Self> Base;
866
937
867
- typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
938
+ typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType > XprType;
868
939
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
869
940
typedef typename XprType::Index Index;
870
941
typedef typename XprType::CoeffReturnType CoeffReturnType;
0 commit comments