
Commit e058955: Add axis to the batched reduce add

1 parent cd11b90 commit e058955

14 files changed: +270 -84 lines changed

include/glow/Base/Tensor.h

Lines changed: 6 additions & 0 deletions

@@ -39,6 +39,12 @@ class Tensor;
 void genericTranspose(Tensor *src, Tensor *dest,
                       llvm::ArrayRef<unsigned> shuffle);
 
+/// Helper function that \returns a ShapeVector of those dimensions in \p
+/// currDims expanded with dimension = 1 until the maximum tensor dimension is
+/// reached. The number of elements in the input dims is the same as in the
+/// returned dims. For example, input {2,1,4} would result in {2,1,4,1,1,1}.
+ShapeVector expandDimsToMax(llvm::ArrayRef<size_t> currDims);
+
 /// A class that represents a contiguous n-dimensional array (a tensor).
 class Tensor final {
   /// A pointer to the tensor data.

include/glow/Graph/Graph.h

Lines changed: 3 additions & 2 deletions

@@ -350,10 +350,11 @@ class Function final : public Named {
                                NodeValue rhs);
 
   BatchedReduceAddNode *createBatchedReduceAdd(llvm::StringRef name,
-                                               NodeValue batch);
+                                               NodeValue batch, size_t axis);
 
   BatchedReduceAddNode *createBatchedReduceAdd(llvm::StringRef name,
-                                               TypeRef outTy, NodeValue batch);
+                                               TypeRef outTy, NodeValue batch,
+                                               size_t axis);
 
   BatchedAddNode *createBatchedAdd(llvm::StringRef name, NodeValue batch,
                                    NodeValue sample);
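
For orientation, a minimal caller-side sketch of the updated API. This is not part of the commit; the Module/Variable setup mirrors Glow's examples and the exact helper signatures may differ by version:

#include "glow/Graph/Graph.h"

// Hypothetical builder code exercising the new axis parameter.
void buildReduceExample(glow::Module &mod) {
  using namespace glow;
  Function *F = mod.createFunction("main");
  // A float batch of shape {8, 3, 4}.
  auto *batch = mod.createVariable(ElemKind::FloatTy, {8, 3, 4}, "batch");
  // Reduce over axis 1: the result has shape {8, 4}. Passing axis = 0
  // reproduces the previous, axis-less behavior (reduce over the batch dim).
  auto *R = F->createBatchedReduceAdd("reduce.add", batch, /*axis=*/1);
  F->createSave("save", R);
}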

lib/Backends/CPU/LLVMIRGen.cpp

Lines changed: 14 additions & 8 deletions

@@ -217,7 +217,8 @@ void LLVMIRGen::initCodeGen() {
 
 /// \returns the LLVM type corresponding to the type of elements stored in \p
 /// val.
-llvm::Type *LLVMIRGen::getElementType(llvm::IRBuilder<> &builder, const Value *val) {
+llvm::Type *LLVMIRGen::getElementType(llvm::IRBuilder<> &builder,
+                                      const Value *val) {
   switch (val->getElementType()) {
   case ElemKind::IndexTy:
     return builder.getIntNTy(sizeof(size_t) * 8);

@@ -1305,11 +1306,14 @@ void LLVMIRGen::generateLLVMIRForInstr(llvm::IRBuilder<> &builder,
   auto *batch = BR->getBatch();
   auto *destPtr = emitValueAddress(builder, dest);
   auto *batchPtr = emitValueAddress(builder, batch);
+  auto *axis = emitConstSizeT(builder, BR->getAxis());
 
-  auto *destSize = emitConstSizeT(builder, dest->size());
-  auto bdim = flattenCdr(batch->dims());
-  auto *numSlice = emitConstSizeT(builder, bdim.first);
-  auto *sliceSize = emitConstSizeT(builder, bdim.second);
+  ShapeVector eBatchDims = expandDimsToMax(batch->dims());
+  ShapeVector eDestDims = eBatchDims;
+  eDestDims[BR->getAxis()] = 1;
+
+  auto *batchDims = emitConstArray(builder, eBatchDims);
+  auto *destDims = emitConstArray(builder, eDestDims);
 
   auto *F = getFunction("batchedreduceadd", dest->getElementType());
 
@@ -1332,11 +1336,13 @@ void LLVMIRGen::generateLLVMIRForInstr(llvm::IRBuilder<> &builder,
     auto *batchScale = emitConstI32(builder, batchScaleParams.scale_);
 
     createCall(builder, F,
-               {destPtr, batchPtr, destSize, numSlice, sliceSize, destOffset,
-                batchOffset, batchPre, batchPost, batchScale});
+               {destPtr, batchPtr, destDims, batchDims, destOffset,
+                batchOffset, batchPre, batchPost, batchScale, axis});
   } else {
+    auto *destSize = emitConstSizeT(builder, dest->size());
+
     createCall(builder, F,
-               {destPtr, batchPtr, destSize, numSlice, sliceSize});
+               {destPtr, batchPtr, destSize, destDims, batchDims, axis});
   }
   break;
 }
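
To make the shape handling concrete, here is a standalone sketch of the dims the IRGen now computes before emitting the libjit call. It is plain C++, independent of Glow's types, and assumes max_tensor_dimensions == 6 as in this commit:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const size_t maxDims = 6;                  // stand-in for max_tensor_dimensions
  std::vector<size_t> batchDims = {8, 3, 4}; // original batch shape
  const size_t axis = 1;

  // expandDimsToMax: pad the shape with 1s up to the maximum rank.
  std::vector<size_t> eBatchDims = batchDims;
  while (eBatchDims.size() < maxDims)
    eBatchDims.push_back(1);

  // The dest dims match the batch dims, with the reduced axis set to 1.
  std::vector<size_t> eDestDims = eBatchDims;
  eDestDims[axis] = 1;

  assert((eBatchDims == std::vector<size_t>{8, 3, 4, 1, 1, 1}));
  assert((eDestDims == std::vector<size_t>{8, 1, 4, 1, 1, 1}));
}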

lib/Backends/CPU/libjit/libjit.cpp

Lines changed: 58 additions & 21 deletions

@@ -694,32 +694,69 @@ void libjit_batchedadd_i8(int8_t *dest, const int8_t *batch,
   }
 }
 
+/// The dimensions passed in here are pre-expanded in LLVMIRGen with 1s so that
+/// we can iterate over the shape here regardless of the tensor's rank.
 void libjit_batchedreduceadd_f(float *dest, const float *batch, size_t destSize,
-                               size_t numSlice, size_t sliceSize) {
-  for (size_t i = 0; i < destSize; i++) {
+                               const size_t *destDims, const size_t *batchDims,
+                               size_t axis) {
+  for (size_t i = 0; i < destSize; i++)
     dest[i] = 0.0;
-  }
-  for (size_t n = 0; n < numSlice; n++) {
-    size_t base = n * sliceSize;
-    for (size_t i = 0; i < sliceSize; i++) {
-      dest[i] += batch[base + i];
-    }
-  }
+
+  for (size_t x = 0; x < batchDims[0]; x++)
+    for (size_t y = 0; y < batchDims[1]; y++)
+      for (size_t z = 0; z < batchDims[2]; z++)
+        for (size_t w = 0; w < batchDims[3]; w++)
+          for (size_t q = 0; q < batchDims[4]; q++)
+            for (size_t r = 0; r < batchDims[5]; r++) {
+              size_t I[] = {x, y, z, w, q, r};
+              I[axis] = 0;
+              dest[libjit_getXYZWQR(destDims, I[0], I[1], I[2], I[3], I[4],
+                                    I[5])] +=
+                  batch[libjit_getXYZWQR(batchDims, x, y, z, w, q, r)];
+            }
 }
 
+/// As in the non-quantized version, the dimensions here are pre-expanded in
+/// LLVMIRGen. However, for quantization we must accumulate in the inner-most
+/// loop with higher precision (int32_t) and then clip the result back into the
+/// dest tensor. Thus we add max_tensor_dimensions different cases for this to
+/// ensure the axis is used as the inner-most loop.
void libjit_batchedreduceadd_i8(int8_t *dest, const int8_t *batch,
-                                size_t destSize, size_t numSlice,
-                                size_t sliceSize, int32_t destOffset,
-                                int32_t batchOffset, int32_t batchPre,
-                                int32_t batchPost, int32_t batchScale) {
-  for (size_t i = 0; i < sliceSize; i++) {
-    int32_t sum = 0;
-    for (size_t n = 0; n < numSlice; n++) {
-      sum += batch[n * sliceSize + i] - batchOffset;
-    }
-    int32_t q =
-        libjit_scale_i32i8(sum, batchPre, batchPost, batchScale, destOffset);
-    dest[i] = libjit_clip(q);
+                                const size_t *destDims, const size_t *batchDims,
+                                int32_t destOffset, int32_t batchOffset,
+                                int32_t batchPre, int32_t batchPost,
+                                int32_t batchScale, size_t axis) {
+  switch (axis) {
+#define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5_AXIS)                     \
+  case _D5_AXIS:                                                              \
+    for (size_t i##_D0 = 0; i##_D0 < batchDims[_D0]; i##_D0++)                \
+      for (size_t i##_D1 = 0; i##_D1 < batchDims[_D1]; i##_D1++)              \
+        for (size_t i##_D2 = 0; i##_D2 < batchDims[_D2]; i##_D2++)            \
+          for (size_t i##_D3 = 0; i##_D3 < batchDims[_D3]; i##_D3++)          \
+            for (size_t i##_D4 = 0; i##_D4 < batchDims[_D4]; i##_D4++) {      \
+              int32_t sum = 0;                                                \
+              for (size_t i##_D5_AXIS = 0; i##_D5_AXIS < batchDims[_D5_AXIS]; \
+                   i##_D5_AXIS++) {                                           \
+                sum += batch[libjit_getXYZWQR(batchDims, i0, i1, i2, i3, i4,  \
+                                              i5)] -                          \
+                       batchOffset;                                           \
+              }                                                               \
+              size_t i##_D5_AXIS = 0;                                         \
+              int32_t res = libjit_scale_i32i8(sum, batchPre, batchPost,      \
+                                               batchScale, destOffset);       \
+              dest[libjit_getXYZWQR(destDims, i0, i1, i2, i3, i4, i5)] =      \
+                  libjit_clip(res);                                           \
+            }                                                                 \
+    return;
+
+  // Each loop order, with the inner-most dimension/index equal to the axis.
+  LOOP_AXIS_CASE(1, 2, 3, 4, 5, 0);
+  LOOP_AXIS_CASE(0, 2, 3, 4, 5, 1);
+  LOOP_AXIS_CASE(0, 1, 3, 4, 5, 2);
+  LOOP_AXIS_CASE(0, 1, 2, 4, 5, 3);
+  LOOP_AXIS_CASE(0, 1, 2, 3, 5, 4);
+  LOOP_AXIS_CASE(0, 1, 2, 3, 4, 5);
+#undef LOOP_AXIS_CASE
   }
 }
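
The int32_t accumulator is the point of the quantized variant: summing int8 values directly would overflow almost immediately, so only the final rescaled value is clipped back to int8. A simplified standalone sketch of that pattern follows; the rescale here is illustrative only, while the real code uses libjit_scale_i32i8 with the pre/post/scale parameters:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Stand-in for libjit_clip: saturate an int32 value into int8 range.
static int8_t clipToInt8(int32_t v) {
  return (int8_t)std::min<int32_t>(std::max<int32_t>(v, -128), 127);
}

int main() {
  // A 1-D reduction over 300 quantized elements with batchOffset == 0.
  int8_t batch[300];
  for (int i = 0; i < 300; i++)
    batch[i] = 100;

  // Accumulate in int32: 300 * 100 = 30000, far outside int8's range.
  int32_t sum = 0;
  for (int i = 0; i < 300; i++)
    sum += batch[i]; // minus batchOffset, omitted since it is 0 here

  // Rescale once at the end, then clip into the destination type.
  int8_t dest = clipToInt8(sum / 256); // illustrative rescale factor
  printf("sum=%d dest=%d\n", sum, dest);
}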

lib/Backends/CPU/libjit/libjit_defs.h

Lines changed: 9 additions & 0 deletions

@@ -54,6 +54,15 @@ inline void AdduFloat8(float *p, float8 v) {
   StoreuFloat8(p, LoaduFloat8(p) + v);
 }
 
+/// \returns the index of the element at x,y,z,w,q,r.
+inline size_t libjit_getXYZWQR(const size_t *dims, size_t x, size_t y, size_t z,
+                               size_t w, size_t q, size_t r) {
+  return (x * dims[1] * dims[2] * dims[3] * dims[4] * dims[5]) +
+         (y * dims[2] * dims[3] * dims[4] * dims[5]) +
+         (z * dims[3] * dims[4] * dims[5]) + (w * dims[4] * dims[5]) +
+         (q * dims[5]) + r;
+}
+
 /// \returns the index of the element at x,y,z,w,q.
 inline size_t libjit_getXYZWQ(const size_t *dims, size_t x, size_t y, size_t z,
                               size_t w, size_t q) {
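
A quick sanity check of the row-major arithmetic, using hypothetical dims rather than anything from this diff:

#include <cassert>
#include <cstddef>

// Local copy of the helper above, so the check is self-contained.
static size_t getXYZWQR(const size_t *dims, size_t x, size_t y, size_t z,
                        size_t w, size_t q, size_t r) {
  return (x * dims[1] * dims[2] * dims[3] * dims[4] * dims[5]) +
         (y * dims[2] * dims[3] * dims[4] * dims[5]) +
         (z * dims[3] * dims[4] * dims[5]) + (w * dims[4] * dims[5]) +
         (q * dims[5]) + r;
}

int main() {
  const size_t dims[6] = {2, 3, 4, 5, 6, 7};
  // 1*2520 + 2*840 + 3*210 + 4*42 + 5*7 + 6 = 5039: the last element of a
  // 5040-element tensor, as expected.
  assert(getXYZWQR(dims, 1, 2, 3, 4, 5, 6) == 5039);
}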

lib/Backends/Interpreter/InterpreterNodes.cpp

Lines changed: 81 additions & 36 deletions

@@ -1302,54 +1302,99 @@ void Interpreter::fwdBatchedAddInst(const glow::BatchedAddInst *I) {
 }
 
 void Interpreter::fwdBatchedReduceAddInst(const glow::BatchedReduceAddInst *I) {
-  if (getTensor(I->getBatch())->getType().isQuantizedType()) {
-    auto dest = getWeightHandle<int8_t>(I->getDest());
-    auto batch = getWeightHandle<int8_t>(I->getBatch());
+  static_assert(max_tensor_dimensions == 6,
+                "Loops below assume max_tensor_dimensions = 6.");
 
-    auto destTy = I->getDest()->getType();
-    auto batchTy = I->getBatch()->getType();
+  auto *batch = I->getBatch();
+  auto *dest = I->getDest();
+  const auto axis = I->getAxis();
+
+  // Initialize both expanded batch and dest dims to the expanded batch dims.
+  // This lets the max_tensor_dimensions loops below iterate over the tensor
+  // regardless of its shape.
+  ShapeVector eBatchDims = expandDimsToMax(batch->dims());
+  ShapeVector eDestDims = eBatchDims;
+
+  // Set the destination axis dimension (the one we are reducing) to 1.
+  eDestDims[axis] = 1;
+
+  if (getTensor(batch)->getType().isQuantizedType()) {
+    auto destTy = dest->getType();
+    auto batchTy = batch->getType();
 
     float destScale = destTy->getScale();
     float batchScale = batchTy->getScale();
 
     int32_t destOffset = destTy->getOffset();
     int32_t batchOffset = batchTy->getOffset();
 
-    auto bdim = flattenCdr(batch.dims());
-
-    // The following loop order is inefficient but easy to implement correctly;
-    // as this is the Interpreter, we prioritize simplicity and correctness
-    // above all else.
-    // For each element in the slice:
-    for (size_t i = 0; i < bdim.second; i++) {
-      float sum = 0.0;
-
-      // For each layer in the batch:
-      for (size_t n = 0; n < bdim.first; n++) {
-        size_t base = batch.getElementPtr({n});
-        sum += batch.raw(base + i) - batchOffset;
-      }
+    // Get unowned handles of the batch and dest with these new expanded dims.
+    auto eBatch = getTensor(batch)->getUnowned(eBatchDims);
+    auto eDest = getTensor(dest)->getUnowned(eDestDims);
+    auto eBatchH = eBatch.getHandle<int8_t>();
+    auto eDestH = eDest.getHandle<int8_t>();
+    eDestH.clear();
+
+    // For quantization, we must accumulate in the inner-most loop into a local
+    // float and then clip the result back into the dest tensor. Here are the
+    // max_tensor_dimensions cases for this, to ensure the axis is used as the
+    // inner-most loop.
+    switch (axis) {
+#define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5_AXIS)                     \
+  case _D5_AXIS:                                                              \
+    for (size_t i##_D0 = 0; i##_D0 < eBatchDims[_D0]; i##_D0++)               \
+      for (size_t i##_D1 = 0; i##_D1 < eBatchDims[_D1]; i##_D1++)             \
+        for (size_t i##_D2 = 0; i##_D2 < eBatchDims[_D2]; i##_D2++)           \
+          for (size_t i##_D3 = 0; i##_D3 < eBatchDims[_D3]; i##_D3++)         \
+            for (size_t i##_D4 = 0; i##_D4 < eBatchDims[_D4]; i##_D4++) {     \
+              float sum = 0.0;                                                \
+              for (size_t i##_D5_AXIS = 0;                                    \
+                   i##_D5_AXIS < eBatchDims[_D5_AXIS]; i##_D5_AXIS++) {       \
+                sum += eBatchH.at({i0, i1, i2, i3, i4, i5}) - batchOffset;    \
+              }                                                               \
+              size_t i##_D5_AXIS = 0;                                         \
+              int32_t res =                                                   \
+                  std::round(sum * batchScale / destScale) + destOffset;      \
+              eDestH.at({i0, i1, i2, i3, i4, i5}) =                           \
+                  quantization::clip<int32_t, int8_t>(res);                   \
+            }                                                                 \
+    return;
 
-      int32_t q = std::round(sum * batchScale / destScale) + destOffset;
-      dest.raw(i) = quantization::clip<int32_t, int8_t>(q);
+    // Each loop order, with the inner-most dimension/index equal to the axis.
+    LOOP_AXIS_CASE(1, 2, 3, 4, 5, 0);
+    LOOP_AXIS_CASE(0, 2, 3, 4, 5, 1);
+    LOOP_AXIS_CASE(0, 1, 3, 4, 5, 2);
+    LOOP_AXIS_CASE(0, 1, 2, 4, 5, 3);
+    LOOP_AXIS_CASE(0, 1, 2, 3, 5, 4);
+    LOOP_AXIS_CASE(0, 1, 2, 3, 4, 5);
+#undef LOOP_AXIS_CASE
+    default:
+      llvm_unreachable("Axis should be less than max_tensor_dimensions.");
     }
-    return;
   }
 
-  auto batch = getWeightHandle(I->getBatch());
-  auto dest = getWeightHandle(I->getDest());
-
-  auto bdim = flattenCdr(batch.dims());
-
-  dest.clear();
-
-  // For each layer in the batch:
-  for (size_t n = 0; n < bdim.first; n++) {
-    size_t base = batch.getElementPtr({n});
-
-    // For each element in the slice:
-    for (size_t i = 0; i < bdim.second; i++) {
-      dest.raw(i) += batch.raw(base + i);
+  // Get unowned handles of the batch and dest with these new expanded dims.
+  auto eBatch = getTensor(batch)->getUnowned(eBatchDims);
+  auto eDest = getTensor(dest)->getUnowned(eDestDims);
+  auto eBatchH = eBatch.getHandle();
+  auto eDestH = eDest.getHandle();
+  eDestH.clear();
+
+  // We can use this loop for all shapes. Use the same indices for both the
+  // batch and dest, except for setting the axis index in the dest to 0.
+  for (size_t x = 0; x < eBatchDims[0]; x++) {
+    for (size_t y = 0; y < eBatchDims[1]; y++) {
+      for (size_t z = 0; z < eBatchDims[2]; z++) {
+        for (size_t w = 0; w < eBatchDims[3]; w++) {
+          for (size_t q = 0; q < eBatchDims[4]; q++) {
+            for (size_t r = 0; r < eBatchDims[5]; r++) {
+              size_t destIndices[] = {x, y, z, w, q, r};
+              destIndices[axis] = 0;
+              eDestH.at(destIndices) += eBatchH.at({x, y, z, w, q, r});
+            }
+          }
+        }
+      }
     }
   }
 }
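
The destIndices[axis] = 0 trick is easiest to see on a small case: every batch coordinate maps to the same dest coordinate except along the reduced axis, which pins to 0. A standalone 2-D sketch, using plain arrays and assuming nothing from Glow:

#include <cassert>
#include <cstddef>

int main() {
  // A 2x3 batch reduced over axis 0; dest keeps the expanded shape {1, 3}.
  const size_t batchDims[2] = {2, 3};
  const size_t axis = 0;
  float batch[2][3] = {{1, 2, 3}, {4, 5, 6}};
  float dest[1][3] = {}; // zero-initialized, like eDestH.clear()

  for (size_t x = 0; x < batchDims[0]; x++)
    for (size_t y = 0; y < batchDims[1]; y++) {
      // Same indices for batch and dest, except the axis index pins to 0.
      size_t destIndices[2] = {x, y};
      destIndices[axis] = 0;
      dest[destIndices[0]][destIndices[1]] += batch[x][y];
    }

  // Column sums: {1+4, 2+5, 3+6}.
  assert(dest[0][0] == 5 && dest[0][1] == 7 && dest[0][2] == 9);
}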

lib/Backends/OpenCL/OpenCL.cpp

Lines changed: 2 additions & 0 deletions

@@ -739,6 +739,8 @@ void OCLBackend::doForwardPass() {
     }
 
     if (auto *BRA = dyn_cast<BatchedReduceAddInst>(&I)) {
+      assert(BRA->getAxis() == 0 && "No current support for non-zero axis.");
+
       cl_kernel kernel = createKernel(kernelName);
       setKernelArg(kernel, 0, deviceBuffer_);

lib/Base/Tensor.cpp

Lines changed: 8 additions & 0 deletions

@@ -335,3 +335,11 @@ void glow::genericTranspose(Tensor *src, Tensor *dest,
     }
   }
 }
+
+ShapeVector glow::expandDimsToMax(llvm::ArrayRef<size_t> currDims) {
+  ShapeVector newDims(currDims.begin(), currDims.end());
+  for (size_t i = newDims.size(); i < max_tensor_dimensions; i++) {
+    newDims.push_back(1);
+  }
+  return newDims;
+}
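
A small usage sketch of the new helper (a fragment; assumes the Glow headers and max_tensor_dimensions == 6):

// {2, 1, 4} is padded with trailing 1s up to the maximum rank; the element
// count (the product of the dims) is unchanged.
ShapeVector dims = glow::expandDimsToMax({2, 1, 4});
// dims is now {2, 1, 4, 1, 1, 1}.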

lib/Graph/Graph.cpp

Lines changed: 17 additions & 7 deletions

@@ -982,18 +982,28 @@ MatMulNode *Function::createMatMul(llvm::StringRef name, NodeValue lhs,
 
 BatchedReduceAddNode *Function::createBatchedReduceAdd(llvm::StringRef name,
                                                        TypeRef outTy,
-                                                       NodeValue batch) {
-  assert(outTy->size() == flattenCdr(batch.dims()).second);
+                                                       NodeValue batch,
+                                                       size_t axis) {
+  // Calculate the expected total number of elements in the output tensor based
+  // on the number of elements in the batch divided by the axis dimension.
+  const size_t outNumElements = batch.getType()->size() / batch.dims()[axis];
+  (void)outNumElements;
+  assert(outTy->size() == outNumElements &&
+         "Incorrect number of elements in the output type.");
   auto OT = getParent()->uniqueType(*outTy);
-  return addNode(new BatchedReduceAddNode(name, OT, batch));
+  return addNode(new BatchedReduceAddNode(name, OT, batch, axis));
 }
 
 BatchedReduceAddNode *Function::createBatchedReduceAdd(llvm::StringRef name,
-                                                       NodeValue batch) {
+                                                       NodeValue batch,
+                                                       size_t axis) {
   auto BT = batch.getType();
-  auto OT =
-      getParent()->uniqueType(BT->getElementType(), BT->dims().drop_front());
-  return createBatchedReduceAdd(name, OT, batch);
+
+  ShapeVector outDims(BT->dims().begin(), BT->dims().end());
+  outDims.erase(outDims.begin() + axis);
+
+  auto OT = getParent()->uniqueType(BT->getElementType(), outDims);
+  return createBatchedReduceAdd(name, OT, batch, axis);
 }
 
 BatchedAddNode *Function::createBatchedAdd(llvm::StringRef name,
0 commit comments

Comments
 (0)