Skip to content

Commit c30905a

Browse files
committed
[Importer] Add C2 importer support for RWQ SLWS/SLS
1 parent b1668cd commit c30905a

8 files changed

+433
-7
lines changed

include/glow/Graph/Graph.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,12 +580,26 @@ class Function final : public Named {
580580
/// Result[0], next Lengths[1] slices are aggregated to Result[1],
581581
/// etc. I.e. sum(Lengths) must be equal to len(Indices).
582582
RowwiseQuantizedSparseLengthsWeightedSumNode *
583+
createRowwiseQuantizedSparseLengthsWeightedSum(
584+
llvm::StringRef name, Constant *data, Constant *scales, Constant *offsets,
585+
NodeValue weights, NodeValue indices, NodeValue lengths);
586+
587+
/// Same as \ref createRowwiseQuantizedSparseLengthsSum(), but expects
588+
/// float input \p data, which is rowwise-quantized internally.
589+
RowwiseQuantizedSparseLengthsWeightedSumNode *
583590
createRowwiseQuantizedSparseLengthsSum(llvm::StringRef name, Tensor &data,
584591
NodeValue indices, NodeValue lengths);
585592

586593
/// Same as \ref createRowwiseQuantizedSparseLengthsWeightedSum(), but every
587594
/// slice is summed with an implicit weight of 1.0 (no weights input).
588595
RowwiseQuantizedSparseLengthsWeightedSumNode *
596+
createRowwiseQuantizedSparseLengthsSum(llvm::StringRef name, Constant *data,
597+
Constant *scales, Constant *offsets,
598+
NodeValue indices, NodeValue lengths);
599+
600+
/// Same as \ref createRowwiseQuantizedSparseLengthsWeightedSum(), but expects
601+
/// float input \p data, which is rowwise-quantized internally.
602+
RowwiseQuantizedSparseLengthsWeightedSumNode *
589603
createRowwiseQuantizedSparseLengthsWeightedSum(llvm::StringRef name,
590604
Tensor &data,
591605
NodeValue weights,

lib/Graph/Graph.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,6 +1376,28 @@ Function::createSparseLengthsWeightedSum(llvm::StringRef name, TypeRef outTy,
13761376
indices, lengths));
13771377
}
13781378

1379+
/// Creates a RowwiseQuantizedSparseLengthsWeightedSumNode named \p name from
/// already rowwise-quantized \p data together with its per-row \p scales and
/// \p offsets, plus the \p weights, \p indices and \p lengths inputs. The
/// result type is float.
RowwiseQuantizedSparseLengthsWeightedSumNode *
1380+
Function::createRowwiseQuantizedSparseLengthsWeightedSum(
1381+
llvm::StringRef name, Constant *data, Constant *scales, Constant *offsets,
1382+
NodeValue weights, NodeValue indices, NodeValue lengths) {
1383+
// The output keeps the shape of \p data, except that its outermost dimension
// is replaced by lengths.dims()[0], i.e. one output row per entry of
// \p lengths.
auto inDims = data->dims();
1384+
ShapeVector outDims(inDims.begin(), inDims.end());
1385+
outDims[0] = lengths.dims()[0];
1386+
auto outTy = getParent()->uniqueType(ElemKind::FloatTy, outDims);
1387+
return addNode(new RowwiseQuantizedSparseLengthsWeightedSumNode(
1388+
name, outTy, data, scales, offsets, weights, indices, lengths));
1389+
}
1390+
1391+
/// Creates the unweighted variant of the rowwise-quantized sparse lengths sum
/// from already rowwise-quantized \p data with its per-row \p scales and
/// \p offsets. It is expressed as a weighted sum whose weights are all 1.0, so
/// the returned node is a RowwiseQuantizedSparseLengthsWeightedSumNode.
RowwiseQuantizedSparseLengthsWeightedSumNode *
1392+
Function::createRowwiseQuantizedSparseLengthsSum(
1393+
llvm::StringRef name, Constant *data, Constant *scales, Constant *offsets,
1394+
NodeValue indices, NodeValue lengths) {
1395+
// Build a float splat of 1.0 with one element per entry of \p indices; it is
// passed as the weights input below.
auto ty = getParent()->uniqueType(ElemKind::FloatTy, {indices.dims()[0]});
1396+
auto ones = createSplat(name.str() + ".ones", ty, 1.0);
1397+
return createRowwiseQuantizedSparseLengthsWeightedSum(
1398+
name, data, scales, offsets, ones, indices, lengths);
1399+
}
1400+
13791401
/// Helper to create a RowwiseQuantizedSparseLengthsWeightedSumNode in the
13801402
/// Function \p F with \p name, using \p data, \p weights, \p indices, and \p
13811403
/// lengths as inputs. The provided float data in \p Tensor is rowwise
@@ -1386,9 +1408,6 @@ quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum(
13861408
Function *F, llvm::StringRef name, Tensor &data, NodeValue weights,
13871409
NodeValue indices, NodeValue lengths) {
13881410
auto inDims = data.dims();
1389-
ShapeVector outDims(inDims.begin(), inDims.end());
1390-
outDims[0] = lengths.dims()[0];
1391-
auto outTy = F->getParent()->uniqueType(ElemKind::FloatTy, outDims);
13921411

13931412
// Note: In rwqData, we are using a quantized type, however the scale/offset
13941413
// are set to dummy values 0.0/0. This is because the actually used
@@ -1403,10 +1422,8 @@ quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum(
14031422
quantization::tensorRowwiseQuantization(data, rwqData->getPayload(),
14041423
dataScales->getPayload(),
14051424
dataOffsets->getPayload());
1406-
1407-
return F->addNode(new RowwiseQuantizedSparseLengthsWeightedSumNode(
1408-
name, outTy, rwqData, dataScales, dataOffsets, weights, indices,
1409-
lengths));
1425+
return F->createRowwiseQuantizedSparseLengthsWeightedSum(
1426+
name, rwqData, dataScales, dataOffsets, weights, indices, lengths);
14101427
}
14111428

14121429
RowwiseQuantizedSparseLengthsWeightedSumNode *

lib/Importer/Caffe2ModelLoader.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,80 @@ llvm::Error Caffe2ModelLoader::loadOperator(const caffe2::OperatorDef &op) {
870870
return llvm::Error::success();
871871
}
872872

873+
if (typeName == "SparseLengthsWeightedSum8BitsRowwise" ||
874+
typeName == "SparseLengthsSum8BitsRowwise") {
875+
// If SparseLengthsWeightedSum8BitsRowwise, then the weights are the second
876+
// input and so we need to shift indices/lengths/scalesBiases.
877+
size_t indicesIdx = 1;
878+
size_t lengthsIdx = 2;
879+
size_t scalesBiasesIdx = 3;
880+
if (typeName == "SparseLengthsWeightedSum8BitsRowwise") {
881+
indicesIdx++;
882+
lengthsIdx++;
883+
scalesBiasesIdx++;
884+
}
885+
886+
NodeValue data;
887+
ASSIGN_VALUE_OR_RETURN_ERR(data,
888+
getNodeValueOrCreateConstantByName(op.input(0)));
889+
NodeValue indices;
890+
ASSIGN_VALUE_OR_RETURN_ERR(
891+
indices, getNodeValueOrCreateConstantByName(op.input(indicesIdx)));
892+
NodeValue lengths;
893+
ASSIGN_VALUE_OR_RETURN_ERR(
894+
lengths, getNodeValueOrCreateConstantByName(op.input(lengthsIdx)));
895+
NodeValue scalesBiases;
896+
ASSIGN_VALUE_OR_RETURN_ERR(scalesBiases, getNodeValueOrCreateConstantByName(
897+
op.input(scalesBiasesIdx)));
898+
899+
Constant *scalesBiasesC = llvm::dyn_cast<Constant>(scalesBiases);
900+
RETURN_ERR_IF_NOT(scalesBiasesC, "scales_biases must be Constant.");
901+
Constant *dataC = llvm::dyn_cast<Constant>(data);
902+
RETURN_ERR_IF_NOT(dataC->getElementType() == ElemKind::Int8QTy,
903+
"Data must be Int8QTy.");
904+
905+
const size_t numRows = data.dims()[0];
906+
907+
// Make sure all the shapes make sense.
908+
RETURN_ERR_IF_NOT(lengths.dims().size() == 1, "lengths must be a vector.");
909+
RETURN_ERR_IF_NOT(indices.dims().size() == 1, "indices must be a vector.");
910+
RETURN_ERR_IF_NOT(scalesBiases.dims().size() == 2,
911+
"scale_bias has to be a matrix.");
912+
RETURN_ERR_IF_NOT(scalesBiases.dims()[0] == numRows,
913+
"scale_bias must have the same number of rows as data.");
914+
RETURN_ERR_IF_NOT(scalesBiases.dims()[1] == 2,
915+
"Second dim of scale_bias has to be equal to 2.");
916+
917+
// Now strip out the scales and biases into their own tensors.
918+
Constant *dataScales = G_.getParent()->createConstant(
919+
ElemKind::FloatTy, {numRows}, "dataScales");
920+
Constant *dataOffsets = G_.getParent()->createConstant(
921+
ElemKind::Int32ITy, {numRows}, "dataOffsets");
922+
923+
auto dataScalesH = dataScales->getHandle<float>();
924+
auto dataOffsetsH = dataOffsets->getHandle<int32_t>();
925+
auto scalesBiasesH = scalesBiasesC->getHandle<float>();
926+
for (size_t i = 0, e = numRows; i < e; i++) {
927+
dataScalesH.at({i}) = scalesBiasesH.at({i, 0});
928+
// Caffe2 represents offsets (bias) using float, while Glow uses int32_t.
929+
dataOffsetsH.at({i}) = static_cast<int32_t>(scalesBiasesH.at({i, 1}));
930+
}
931+
932+
Node *node;
933+
if (typeName == "SparseLengthsWeightedSum8BitsRowwise") {
934+
NodeValue weights;
935+
ASSIGN_VALUE_OR_RETURN_ERR(
936+
weights, getNodeValueOrCreateConstantByName(op.input(1)));
937+
node = G_.createRowwiseQuantizedSparseLengthsWeightedSum(
938+
opName, dataC, dataScales, dataOffsets, weights, indices, lengths);
939+
} else {
940+
node = G_.createRowwiseQuantizedSparseLengthsSum(
941+
opName, dataC, dataScales, dataOffsets, indices, lengths);
942+
}
943+
RETURN_IF_ERR(addNodeAsOutput(op, node));
944+
return llvm::Error::success();
945+
}
946+
873947
RETURN_ERR(unexpectedNodeErrorMessage(op, "Unsupported operator."));
874948
}
875949

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
name: "rowwise_quantized_sparse_lengths_sum_init_net_test"
2+
op {
3+
output: "data"
4+
type: "Int8GivenTensorFill"
5+
arg {
6+
name: "shape"
7+
ints: 3
8+
ints: 2
9+
}
10+
arg {
11+
name: "values"
12+
s: "\324\377\254\377\311\377"
13+
}
14+
arg {
15+
name: "Y_zero_point"
16+
i: 0
17+
}
18+
arg {
19+
name: "Y_scale"
20+
f: 0.0
21+
}
22+
}
23+
op {
24+
output: "scales_bias"
25+
type: "GivenTensorFill"
26+
arg {
27+
name: "shape"
28+
ints: 3
29+
ints: 2
30+
}
31+
arg {
32+
name: "values"
33+
floats: 0.004706
34+
floats: -128.0
35+
floats: 0.013333
36+
floats: -128.0
37+
floats: 0.022353
38+
floats: -128.0
39+
}
40+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
name: "rowwise_quantized_sparse_lengths_sum_predict_net_test"
2+
op {
3+
input: "data"
4+
input: "indices"
5+
input: "lengths"
6+
input: "scales_bias"
7+
output: "result"
8+
name: ""
9+
type: "SparseLengthsSum8BitsRowwise"
10+
}
11+
external_input: "indices"
12+
external_input: "lengths"
13+
external_output: "result"
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
name: "rowwise_quantized_sparse_lengths_weighted_sum_init_net_test"
2+
op {
3+
output: "data"
4+
type: "Int8GivenTensorFill"
5+
arg {
6+
name: "shape"
7+
ints: 3
8+
}
9+
arg {
10+
name: "values"
11+
s: "\377\000\377"
12+
}
13+
arg {
14+
name: "Y_zero_point"
15+
i: 0
16+
}
17+
arg {
18+
name: "Y_scale"
19+
f: 0.0
20+
}
21+
}
22+
op {
23+
output: "weights"
24+
type: "GivenTensorFill"
25+
arg {
26+
name: "shape"
27+
ints: 8
28+
}
29+
arg {
30+
name: "values"
31+
floats: 3.0
32+
floats: 1.0
33+
floats: 0.0
34+
floats: 0.0
35+
floats: 0.0
36+
floats: 0.0
37+
floats: 2.0
38+
floats: -0.5
39+
}
40+
}
41+
op {
42+
output: "scales_bias"
43+
type: "GivenTensorFill"
44+
arg {
45+
name: "shape"
46+
ints: 3
47+
ints: 2
48+
}
49+
arg {
50+
name: "values"
51+
floats: 0.007843
52+
floats: -128.0
53+
floats: 0.001961
54+
floats: 127.0
55+
floats: 0.050980
56+
floats: -128.0
57+
}
58+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
name: "rowwise_quantized_sparse_lengths_weighted_sum_predict_net_test"
2+
op {
3+
input: "data"
4+
input: "weights"
5+
input: "indices"
6+
input: "lengths"
7+
input: "scales_bias"
8+
output: "result"
9+
name: ""
10+
type: "SparseLengthsWeightedSum8BitsRowwise"
11+
}
12+
external_input: "indices"
13+
external_input: "lengths"
14+
external_output: "result"

0 commit comments

Comments
 (0)