Commit b975212

[New Operator] FusedRowwiseQuantizedSparseLengthsWeightedSumNode (#2368)
*Description*: As noted in #2292, we decided to implement both fused and unfused versions of rowwise-quantized SLWS.
*Testing*: Added OperatorTests and Caffe2ImporterTests.
*Documentation*: Added.
Closes #1698
1 parent 8b08f35 commit b975212

22 files changed: +1003 -68 lines changed

docs/Quantization.md

Lines changed: 13 additions & 0 deletions
@@ -212,3 +212,16 @@ Row-wise quantized SparseLengthsWeightedSum is also supported. Similar to the
 above, we compute scales and offsets per row, to be used with the `Data` input
 for the `RowwiseQuantizedSparseLengthsSumNode`. Scales and Offsets are inputs to
 the node. Output of this node is float, matching the Caffe2 implementation.
+
+### Fused Row-wise Quantization
+
+For some backends it may be beneficial to keep each row's scales and offsets
+fused inline with the data. Caffe2 implements nodes with fused storage, such as
+[SparseLengthsWeightedSumFused8BitRowwise](https://caffe2.ai/docs/operators-catalogue.html#sparselengthsweightedsumfused8bitrowwise). Glow
+supports such fused Nodes/Instructions, for example
+`FusedRowwiseQuantizedSparseLengthsWeightedSum`. The `ElemKind` of fused tensors
+is `Int8FusedQTy`. Tensors with `Int8FusedQTy` are 2-dimensional, and have an
+extra 8 columns for each row. The first extra 4 bytes are the float scale of the
+row, and the second extra 4 bytes are the int32_t offset. Note that similar to
+normal row-wise quantized tensors, they use a dummy scale and offset in the
+Type.

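To make the fused layout described above concrete, here is a small, self-contained C++ sketch (illustrative only, not part of this patch; the names and dimensions are hypothetical) that reads the trailing scale/offset of a row and dequantizes one element:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Dequantize element `col` of row `row` from an Int8FusedQTy-style buffer.
// Each row is `width` bytes: (width - 8) int8 data bytes, then a 4-byte float
// scale and a 4-byte int32_t offset.
float dequantizeFused(const std::vector<int8_t> &data, size_t width,
                      size_t row, size_t col) {
  const int8_t *rowPtr = data.data() + row * width;
  const int8_t *scaleOffsetPtr = rowPtr + width - 8;
  float scale;
  int32_t offset;
  std::memcpy(&scale, scaleOffsetPtr, sizeof(float));
  std::memcpy(&offset, scaleOffsetPtr + 4, sizeof(int32_t));
  return scale * (rowPtr[col] - offset);
}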
include/glow/Base/Tensor.h

Lines changed: 48 additions & 4 deletions
@@ -107,6 +107,18 @@ class Tensor final {
       auto *data = reinterpret_cast<int32_t *>(getData());
       std::fill(&data[0], &data[0] + size(), (int32_t)type_.getOffset());
     } break;
+    case ElemKind::Int8FusedQTy: {
+      assert(dims().size() == 2 && "Fused tensor must be 2-dimensional.");
+      assert(dims()[1] > 8 && "Fused tensor must have more than 8 columns.");
+      const size_t width = dims()[1];
+      auto *data = reinterpret_cast<int8_t *>(getData());
+      for (size_t i = 0, e = dims()[0]; i < e; i++) {
+        int8_t *scaleOffsetPtr = &data[(i + 1) * width] - 8;
+        int32_t offset;
+        memcpy(&offset, scaleOffsetPtr + 4, 4);
+        std::fill(&data[i * width], scaleOffsetPtr, (int8_t)offset);
+      }
+    } break;
     default:
       // Non-quantized tensors are set to 0.
       std::fill(&getData()[0], &getData()[0] + size() * type_.getElementSize(),
@@ -174,8 +186,9 @@ class Tensor final {
   Tensor &operator=(const Tensor &other) = delete;
 
   /// Initialize the content of the tensor using the \p init method. The value
-  /// \p val is the initialization parameter. \p PRNG is used to generate
-  /// random numbers.
+  /// \p val is the initialization parameter. \p PRNG is used to generate random
+  /// numbers. Note that if the tensor's kind is Int8FusedQTy, then the fused
+  /// scales/offsets will not be modified.
   void init(InitKind init, float val, PseudoRNG &PRNG);
 
   /// \returns unowned tensor using the same data buffer as the current tensor
@@ -288,6 +301,17 @@ class Tensor final {
       return false;
     }
 
+    // For now, make sure that either both or neither of the tensors have
+    // Int8FusedQTy. While it is possible for an Int8QTy tensor to equal a
+    // Int8FusedQTy tensor if the Int8FusedQTy tensor has the same scale/offset
+    // on all of its rows, and that scale/offset match that of the Int8QTy, we
+    // do not support checking this for now.
+    assert(((getElementType() == ElemKind::Int8FusedQTy &&
+             other.getElementType() == ElemKind::Int8FusedQTy) ||
+            (getElementType() != ElemKind::Int8FusedQTy &&
+             other.getElementType() != ElemKind::Int8FusedQTy)) &&
+           "Int8FusedQTy only supports comparing against same ElemKind.");
+
     switch (getElementType()) {
     case ElemKind::FloatTy:
       return isEqualImpl<float>(other, allowedError);
@@ -315,6 +339,11 @@ class Tensor final {
       return isEqualImpl<int32_t>(other, allowedError);
    case ElemKind::Int64ITy:
      return isEqualImpl<int64_t>(other, allowedError);
+    // Note: We can use isEqualImpl() here because the scales/offsets will be
+    // compared as if they were data, so we will return false if any rowwise
+    // scale/offset does not match.
+    case ElemKind::Int8FusedQTy:
+      return isEqualImpl<int8_t>(other, allowedError);
     }
 
    // This is to make compiler happy. It can never reach this point as switch
@@ -701,8 +730,23 @@ template <class ElemTy> class Handle final {
     assert(filterSize > 0 && "invalid filter size");
     double scale = std::sqrt(3.0 / double(filterSize));
     std::uniform_real_distribution<> dist(-scale, scale);
-    for (auto &e : *this) {
-      e = dist(PRNG);
+    switch (getElementType()) {
+    default: {
+      for (auto &e : *this) {
+        e = dist(PRNG);
+      }
+      return;
+    }
+    case ElemKind::Int8FusedQTy: {
+      assert(dims().size() == 2 && "Fused tensor must be 2-dimensional.");
+      assert(dims()[1] > 8 && "Fused tensor must have more than 8 columns.");
+      for (size_t i = 0, e = dims()[0]; i < e; i++) {
+        for (size_t j = 0, f = dims()[1] - 8; j < f; j++) {
+          at({i, j}) = dist(PRNG);
+        }
+      }
+      return;
+    }
     }
   }

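As a rough usage illustration of the Tensor changes above (a hypothetical snippet, not from the patch; it assumes the existing glow::Tensor quantized-type constructor and Handle API), a fused tensor carries a dummy Type-level scale/offset while zero() and initXavier() touch only the data columns:

// 10 rows of 16 int8 data bytes plus 8 fused bytes (float scale + int32_t
// offset) per row. The 1.0 / 0 are the dummy Type scale/offset; the real
// per-row values live in the last 8 bytes of each row.
glow::Tensor fused(glow::ElemKind::Int8FusedQTy, {10, 16 + 8}, 1.0, 0);
glow::PseudoRNG PRNG;
fused.zero(); // Fills each row's data columns with that row's fused offset;
              // the trailing 8 scale/offset bytes are left untouched.
fused.getHandle<int8_t>().initXavier(16, PRNG); // Randomizes data columns only.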
include/glow/Base/Type.h

Lines changed: 15 additions & 9 deletions
@@ -185,13 +185,14 @@ inline bool operator==(const ShapeNCHW &LHS, const ShapeNCHW &RHS) {
 /// An enum representing the type used by the elements of a tensor. The types of
 /// Handles for these tensors should match the element kind.
 enum class ElemKind : unsigned char {
-  FloatTy,   // 32-bit float type (float)
-  Float16Ty, // 16-bit float type (half, fp16)
-  Int8QTy,   // 8-bit quantized type (int8_t)
-  Int16QTy,  // 16-bit quantized type (int16_t)
-  Int32QTy,  // 32-bit quantized type (int32_t)
-  Int32ITy,  // 32-bit index type (int32_t)
-  Int64ITy,  // 64-bit index type (int64_t)
+  FloatTy,      // 32-bit float type (float)
+  Float16Ty,    // 16-bit float type (half, fp16)
+  Int8QTy,      // 8-bit quantized type (int8_t)
+  Int16QTy,     // 16-bit quantized type (int16_t)
+  Int32QTy,     // 32-bit quantized type (int32_t)
+  Int32ITy,     // 32-bit index type (int32_t)
+  Int64ITy,     // 64-bit index type (int64_t)
+  Int8FusedQTy, // 8-bit quantized type with fused scale/offset (int8_t)
 };
 
 /// A class that represents a type of a tensor.
@@ -360,6 +361,8 @@ struct Type final {
       return std::is_same<ElemTy, int32_t>::value;
     case ElemKind::Int64ITy:
       return std::is_same<ElemTy, int64_t>::value;
+    case ElemKind::Int8FusedQTy:
+      return std::is_same<ElemTy, int8_t>::value;
     }
     GLOW_UNREACHABLE("Invalid type.");
   }
@@ -368,7 +371,8 @@ struct Type final {
   bool isQuantizedType() const {
     return elementType_ == ElemKind::Int8QTy ||
            elementType_ == ElemKind::Int16QTy ||
-           elementType_ == ElemKind::Int32QTy;
+           elementType_ == ElemKind::Int32QTy ||
+           elementType_ == ElemKind::Int8FusedQTy;
   }
 
   /// \returns true if the type of this Tensor is one of the floating point
@@ -401,6 +405,8 @@ struct Type final {
       return sizeof(int32_t);
     case ElemKind::Int64ITy:
       return sizeof(int64_t);
+    case ElemKind::Int8FusedQTy:
+      return sizeof(int8_t);
     }
     GLOW_UNREACHABLE("Invalid type.");
   }
@@ -413,7 +419,7 @@ struct Type final {
   /// \return the textual name of the element \p Ty.
   static llvm::StringRef getElementName(ElemKind Ty) {
     static const char *names[] = {
-        "float", "float16", "i8", "i16", "i32", "index32", "index64",
+        "float", "float16", "i8", "i16", "i32", "index32", "index64", "i8fused",
     };
     return names[(int)Ty];
   }

include/glow/Graph/Graph.h

Lines changed: 54 additions & 14 deletions
@@ -573,18 +573,19 @@ class Function final : public Named {
                                  NodeValue data, NodeValue weights,
                                  NodeValue indices, NodeValue lengths);
 
-  /// Create a node, performing SparseLengthsSum operation, using rowwise
-  /// quantization for the input data. Gathers slices of the outer-most
-  /// dimension of Data indexed by Indices vector, and then accumulates them
-  /// into len(Lengths) entries: first Lengths[0] slices are aggregated to
-  /// Result[0], next Lengths[1] slices are aggregated to Result[1],
-  /// etc. I.e. sum(Lengths) must be equal to len(Indices).
+  /// Creates and \returns a node of \p name, performing the SparseLengthsSum
+  /// operation, using rowwise quantization for the input \p data with the \p
+  /// scales and \p offsets as separate input tensors. Gathers slices of the
+  /// outer-most dimension of data indexed by the \p indices vector, and then
+  /// accumulates them into len(\p lengths) entries: first Lengths[0] slices are
+  /// aggregated to Result[0], next Lengths[1] slices are aggregated to
+  /// Result[1], etc. I.e. sum(Lengths) must be equal to len(Indices).
   RowwiseQuantizedSparseLengthsWeightedSumNode *
-  createRowwiseQuantizedSparseLengthsWeightedSum(
-      llvm::StringRef name, Constant *data, Constant *scales, Constant *offsets,
-      NodeValue weights, NodeValue indices, NodeValue lengths);
+  createRowwiseQuantizedSparseLengthsSum(llvm::StringRef name, Constant *data,
+                                         Constant *scales, Constant *offsets,
+                                         NodeValue indices, NodeValue lengths);
 
-  /// Same as \ref createRowwiseQuantizedSparseLengthsWeightedSum(), but expects
+  /// Same as \ref createRowwiseQuantizedSparseLengthsSum(), but expects
   /// float input \p data, which is rowwise-quantized internally.
   RowwiseQuantizedSparseLengthsWeightedSumNode *
   createRowwiseQuantizedSparseLengthsSum(llvm::StringRef name, Tensor &data,
@@ -593,11 +594,11 @@ class Function final : public Named {
   /// Same as \ref createRowwiseQuantizedSparseLengthsSum(), but i-th slice is
   /// multiplied by weights[i]. len(weights) must be equal to len(indices).
   RowwiseQuantizedSparseLengthsWeightedSumNode *
-  createRowwiseQuantizedSparseLengthsSum(llvm::StringRef name, Constant *data,
-                                         Constant *scales, Constant *offsets,
-                                         NodeValue indices, NodeValue lengths);
+  createRowwiseQuantizedSparseLengthsWeightedSum(
+      llvm::StringRef name, Constant *data, Constant *scales, Constant *offsets,
+      NodeValue weights, NodeValue indices, NodeValue lengths);
 
-  /// Same as \ref createRowwiseQuantizedSparseLengthsSum(), but expects
+  /// Same as \ref createRowwiseQuantizedSparseLengthsWeightedSum(), but expects
   /// float input \p data, which is rowwise-quantized internally.
   RowwiseQuantizedSparseLengthsWeightedSumNode *
   createRowwiseQuantizedSparseLengthsWeightedSum(llvm::StringRef name,
@@ -606,6 +607,45 @@ class Function final : public Named {
                                                  NodeValue indices,
                                                  NodeValue lengths);
 
+  /// Creates and \returns a node of \p name, performing the SparseLengthsSum
+  /// operation, using fused rowwise quantization for the input \p data wherein
+  /// the scales and offsets are fused inline with each row of data. \p data
+  /// must be ElemKind::Int8FusedQTy. Gathers slices of the outer-most dimension
+  /// of data indexed by the \p indices vector, and then accumulates them into
+  /// len(\p lengths) entries: first Lengths[0] slices are aggregated to
+  /// Result[0], next Lengths[1] slices are aggregated to Result[1],
+  /// etc. I.e. sum(Lengths) must be equal to len(Indices).
+  FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
+  createFusedRowwiseQuantizedSparseLengthsSum(llvm::StringRef name,
+                                              Constant *data, NodeValue indices,
+                                              NodeValue lengths);
+
+  /// Same as \ref createFusedRowwiseQuantizedSparseLengthsSum(), but expects
+  /// float input \p data, which is rowwise-quantized and fused internally.
+  FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
+  createFusedRowwiseQuantizedSparseLengthsSum(llvm::StringRef name,
+                                              Tensor &data, NodeValue indices,
+                                              NodeValue lengths);
+
+  /// Same as \ref createFusedRowwiseQuantizedSparseLengthsSum(), but i-th slice
+  /// is multiplied by weights[i]. len(weights) must be equal to len(indices).
+  FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
+  createFusedRowwiseQuantizedSparseLengthsWeightedSum(llvm::StringRef name,
+                                                      Tensor &data,
+                                                      NodeValue weights,
+                                                      NodeValue indices,
+                                                      NodeValue lengths);
+
+  /// Same as \ref createFusedRowwiseQuantizedSparseLengthsWeightedSum(), but
+  /// expects float input \p data, which is rowwise-quantized and fused
+  /// internally.
+  FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
+  createFusedRowwiseQuantizedSparseLengthsWeightedSum(llvm::StringRef name,
+                                                      Constant *data,
+                                                      NodeValue weights,
+                                                      NodeValue indices,
+                                                      NodeValue lengths);
+
   /// Given a vector of segment lengths, calculates offsets of each segment and
   /// packs them next to the lengths. For the input vector of length N the
   /// output is a Nx2 matrix with (offset, lengths) packaged for each segment.

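For orientation, here is a rough sketch of how the new creation methods might be used (hypothetical, not part of the patch; it assumes Int64ITy indices and Int32ITy lengths as with the existing SparseLengths nodes, and uses the overload that rowwise-quantizes and fuses float data internally):

glow::Module mod;
glow::Function *F = mod.createFunction("main");

// Float embedding table; the create call below quantizes it per row and fuses
// each row's scale/offset into an Int8FusedQTy constant.
glow::Tensor data(glow::ElemKind::FloatTy, {1000, 64});

auto *indices =
    mod.createPlaceholder(glow::ElemKind::Int64ITy, {100}, "indices", false);
auto *lengths =
    mod.createPlaceholder(glow::ElemKind::Int32ITy, {10}, "lengths", false);

auto *SLS = F->createFusedRowwiseQuantizedSparseLengthsSum("fused_sls", data,
                                                           indices, lengths);
F->createSave("save", SLS);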
include/glow/Quantization/Base/Base.h

Lines changed: 9 additions & 0 deletions
@@ -136,6 +136,15 @@ std::vector<int8_t> createMapping(TypeRef inTy, TypeRef outTy,
 void tensorRowwiseQuantization(const Tensor &input, Tensor &output,
                                Tensor &scales, Tensor &offsets);
 
+/// Fused-rowwise quantize the tensor \p input. Scales and offsets are generated
+/// from each row of \p input. \p output is a tensor of the same shape as input
+/// but with 8 extra columns for storing the fused scale (4 bytes (columns) for
+/// float) and offset (4 bytes (columns) for int32_t).
+/// \pre input.dims().size() == 2
+/// \pre output.dims().size() == 2
+/// \pre input.dims()[1] + 8 == output.dims()[1]
+void tensorFusedRowwiseQuantization(const Tensor &input, Tensor &output);
+
 } // namespace quantization
 } // namespace glow

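The body of tensorFusedRowwiseQuantization is not shown in this excerpt. The following standalone sketch (an assumption about the described behavior using a simple min/max affine scheme, not the library's actual implementation) shows how one float row could be quantized and its scale/offset fused into the trailing 8 bytes:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Quantize one float row into (row.size() + 8) bytes: the int8 data values,
// then a 4-byte float scale and a 4-byte int32_t offset, chosen so that
// value ~= scale * (stored - offset), matching the fused kernels' math.
std::vector<int8_t> fuseQuantizeRow(const std::vector<float> &row) {
  const float min = *std::min_element(row.begin(), row.end());
  const float max = *std::max_element(row.begin(), row.end());
  float scale = (max - min) / 255.0f;
  if (scale == 0.0f) {
    scale = 1.0f; // Constant row; any nonzero scale works.
  }
  // Map min onto -128 (and max onto roughly 127).
  const int32_t offset =
      static_cast<int32_t>(std::round(-128.0f - min / scale));

  std::vector<int8_t> out(row.size() + 8);
  for (size_t i = 0; i < row.size(); i++) {
    const float q = std::round(row[i] / scale + offset);
    out[i] = static_cast<int8_t>(std::max(-128.0f, std::min(127.0f, q)));
  }
  std::memcpy(out.data() + row.size(), &scale, sizeof(float));
  std::memcpy(out.data() + row.size() + 4, &offset, sizeof(int32_t));
  return out;
}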
lib/Backends/CPU/LLVMIRGen.cpp

Lines changed: 30 additions & 0 deletions
@@ -237,6 +237,8 @@ llvm::Type *LLVMIRGen::getElementType(llvm::IRBuilder<> &builder,
     return builder.getInt32Ty();
   case ElemKind::Int32ITy:
     return builder.getInt32Ty();
+  case ElemKind::Int8FusedQTy:
+    return builder.getInt8Ty();
   }
   return nullptr;
 }
@@ -324,6 +326,9 @@ llvm::Value *LLVMIRGen::emitValueAddress(llvm::IRBuilder<> &builder,
   case ElemKind::Int32ITy:
     T = llvm::Type::getInt32PtrTy(ctx_);
     break;
+  case ElemKind::Int8FusedQTy:
+    T = llvm::Type::getInt8PtrTy(ctx_);
+    break;
   default:
     llvm_unreachable("Unimplemented");
     break;
@@ -469,6 +474,8 @@ llvm::Value *LLVMIRGen::emitConst(llvm::IRBuilder<> &builder, float val,
     return builder.getInt32(static_cast<int32_t>(val));
   case ElemKind::Int32ITy:
     return builder.getInt32(static_cast<int32_t>(val));
+  case ElemKind::Int8FusedQTy:
+    return builder.getInt8(static_cast<int8_t>(val));
   }
   llvm_unreachable("Unknown element type");
 }
@@ -2318,6 +2325,29 @@ void LLVMIRGen::generateLLVMIRForInstr(llvm::IRBuilder<> &builder,
     break;
   }
 
+  case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumInstKind: {
+    auto *N = cast<FusedRowwiseQuantizedSparseLengthsWeightedSumInst>(I);
+    auto *dest = N->getDest();
+    auto *data = N->getData();
+    auto *weights = N->getWeights();
+    auto *indices = N->getIndices();
+    auto *lengths = N->getLengths();
+    auto *destPtr = emitValueAddress(builder, dest);
+    auto *dataPtr = emitValueAddress(builder, data);
+    auto *weightsPtr = emitValueAddress(builder, weights);
+    auto *indicesPtr = emitValueAddress(builder, indices);
+    auto *lengthsPtr = emitValueAddress(builder, lengths);
+    auto *segments = emitConstSizeT(builder, lengths->dims()[0]);
+    auto *inLineSize = emitConstSizeT(builder, data->size() / data->dims()[0]);
+    auto *outLineSize = emitConstSizeT(builder, dest->size() / dest->dims()[0]);
+    auto *F = getFunction("fused_rowwise_quantized_sparse_lengths_weighted_sum",
+                          dest->getElementType());
+    createCall(builder, F,
+               {destPtr, dataPtr, weightsPtr, indicesPtr, lengthsPtr, segments,
+                inLineSize, outLineSize});
+    break;
+  }
+
   case Kinded::Kind::SparseToDenseInstKind: {
     auto *STDI = llvm::cast<SparseToDenseInst>(I);
     auto *indices = STDI->getIndices();

lib/Backends/CPU/libjit/libjit.cpp

Lines changed: 24 additions & 0 deletions
@@ -1057,6 +1057,30 @@ void libjit_rowwise_quantized_sparse_lengths_weighted_sum_f(
   }
 }
 
+void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_f(
+    float *dest, int8_t *data, float *weights, size_t *indices,
+    int32_t *lengths, size_t segments, size_t inLineSize, size_t outLineSize) {
+  memset(dest, 0, segments * outLineSize * sizeof(float));
+  size_t curIndex = 0;
+  for (size_t i = 0; i < segments; i++) {
+    for (int32_t j = 0, e = lengths[i]; j < e; j++) {
+      const float weight = weights[curIndex];
+      const size_t line = indices[curIndex];
+      const int8_t *currRowScaleOffsetPtr =
+          data + ((line + 1) * inLineSize) - 8;
+      float scale;
+      int32_t offset;
+      memcpy(&scale, currRowScaleOffsetPtr, sizeof(float));
+      memcpy(&offset, currRowScaleOffsetPtr + 4, sizeof(int32_t));
+      for (size_t k = 0; k < outLineSize; k++) {
+        const float fData = scale * (data[line * inLineSize + k] - offset);
+        dest[i * outLineSize + k] += weight * fData;
+      }
+      curIndex++;
+    }
+  }
+}
+
 void libjit_sparse_to_dense_f(float *dest, const size_t *indices,
                               const float *values, size_t numIndices,
                               size_t destSize, size_t valueSize) {

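A tiny worked example of the per-element math in the kernel above (the numbers are hypothetical): a row stored with scale 0.5 and offset -10 dequantizes a stored byte of 20 to 0.5 * (20 - (-10)) = 15.0, and with a weight of 2.0 it contributes 30.0 to its segment's accumulator:

const float scale = 0.5f;
const int32_t offset = -10;
const int8_t stored = 20;
const float weight = 2.0f;
const float fData = scale * (stored - offset); // 15.0f
const float contribution = weight * fData;     // 30.0f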