Skip to content

Commit 3f0740f

Browse files
committed
[Tensor] Change init so it doesn't touch fused scale/offset.
1 parent a7e795f commit 3f0740f

File tree

3 files changed

+117
-5
lines changed

3 files changed

+117
-5
lines changed

include/glow/Base/Tensor.h

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,18 @@ class Tensor final {
107107
auto *data = reinterpret_cast<int32_t *>(getData());
108108
std::fill(&data[0], &data[0] + size(), (int32_t)type_.getOffset());
109109
} break;
110+
case ElemKind::Int8FusedQTy: {
111+
assert(dims().size() == 2 && "Fused tensor must be 2-dimensional.");
112+
assert(dims()[1] > 8 && "Fused tensor must have more than 8 columns.");
113+
const size_t width = dims()[1];
114+
auto *data = reinterpret_cast<int8_t *>(getData());
115+
for (size_t i = 0, e = dims()[0]; i < e; i++) {
116+
int8_t *scaleOffsetPtr = &data[(i + 1) * width] - 8;
117+
int32_t offset;
118+
memcpy(&offset, scaleOffsetPtr + 4, 4);
119+
std::fill(&data[i * width], scaleOffsetPtr, (int8_t)offset);
120+
}
121+
} break;
110122
default:
111123
// Non-quantized tensors are set to 0.
112124
std::fill(&getData()[0], &getData()[0] + size() * type_.getElementSize(),
@@ -174,8 +186,9 @@ class Tensor final {
174186
Tensor &operator=(const Tensor &other) = delete;
175187

176188
/// Initialize the content of the tensor using the \p init method. The value
177-
/// \p val is the initialization parameter. \p PRNG is used to generate
178-
/// random numbers.
189+
/// \p val is the initialization parameter. \p PRNG is used to generate random
190+
/// numbers. Note that if the tensor's kind is Int8FusedQTy, then the fused
191+
/// scaled/offsets will not be modified.
179192
void init(InitKind init, float val, PseudoRNG &PRNG);
180193

181194
/// \returns unowned tensor using the same data buffer as the current tensor
@@ -717,8 +730,23 @@ template <class ElemTy> class Handle final {
717730
assert(filterSize > 0 && "invalid filter size");
718731
double scale = std::sqrt(3.0 / double(filterSize));
719732
std::uniform_real_distribution<> dist(-scale, scale);
720-
for (auto &e : *this) {
721-
e = dist(PRNG);
733+
switch (getElementType()) {
734+
default: {
735+
for (auto &e : *this) {
736+
e = dist(PRNG);
737+
}
738+
return;
739+
}
740+
case ElemKind::Int8FusedQTy: {
741+
assert(dims().size() == 2 && "Fused tensor must be 2-dimensional.");
742+
assert(dims()[1] > 8 && "Fused tensor must have more than 8 columns.");
743+
for (size_t i = 0, e = dims()[0]; i < e; i++) {
744+
for (size_t j = 0, f = dims()[1] - 8; j < f; j++) {
745+
at({i, j}) = dist(PRNG);
746+
}
747+
}
748+
return;
749+
}
722750
}
723751
}
724752

lib/Base/Tensor.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,14 @@ void Tensor::init(InitKind init, float val, PseudoRNG &PRNG) {
423423
break;
424424
}
425425
case ElemKind::Int8FusedQTy: {
426-
getHandle<int8_t>().clear(val);
426+
assert(dims().size() == 2 && "Fused tensor must be 2-dimensional.");
427+
assert(dims()[1] > 8 && "Fused tensor must have more than 8 columns.");
428+
auto H = getHandle<int8_t>();
429+
for (size_t i = 0; i < dims()[0]; i++) {
430+
for (size_t j = 0, f = dims()[1] - 8; j < f; j++) {
431+
H.at({i, j}) = val;
432+
}
433+
}
427434
break;
428435
}
429436
}

tests/unittests/TensorsTest.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,3 +791,80 @@ TEST(Tensor, insertSlice) {
791791
3.0f, 4.0f, 0.0f, 0.0f, 0.0f, 0.0f};
792792
EXPECT_TRUE(big.isEqual(expected));
793793
}
794+
795+
/// Check that after initializing a fused tensor to zero that the scale and
796+
/// offset are not changed and that the values for each row are set to that
797+
/// row's offset.
798+
TEST(Tensor, initZeroFused) {
799+
Tensor T(ElemKind::Int8FusedQTy, {10, 10}, 0.0, 0);
800+
auto TH = T.getHandle<int8_t>();
801+
TH.clear(127);
802+
for (size_t i = 0; i < 10; i++) {
803+
for (size_t j = 2; j < 10; j++) {
804+
// Set 6 due to endianess when loading the int32_t offset.
805+
if (j == 6) {
806+
TH.at({i, j}) = i + 100;
807+
} else {
808+
TH.at({i, j}) = 0;
809+
}
810+
}
811+
}
812+
PseudoRNG PRNG;
813+
T.init(Tensor::InitKind::Zero, 1, PRNG);
814+
for (size_t i = 0; i < 10; i++) {
815+
for (size_t j = 0; j < 10; j++) {
816+
// Now check that both the offset and the values are correct, and that all
817+
// other values are still 0.
818+
if (j < 2 || j == 6) {
819+
EXPECT_EQ(TH.at({i, j}), i + 100);
820+
} else {
821+
EXPECT_EQ(TH.at({i, j}), 0);
822+
}
823+
}
824+
}
825+
}
826+
827+
/// Check that initializing a fused tensor with Xavier that the scale and offset
828+
/// are not changed.
829+
TEST(Tensor, initXavierFused) {
830+
Tensor T(ElemKind::Int8FusedQTy, {10, 10}, 0.0, 0);
831+
PseudoRNG PRNG;
832+
auto TH = T.getHandle<int8_t>();
833+
for (size_t i = 0; i < 10; i++) {
834+
for (size_t j = 0; j < 10; j++) {
835+
TH.at({i, j}) = i * 10 + j;
836+
}
837+
}
838+
T.init(Tensor::InitKind::Xavier, 1, PRNG);
839+
for (size_t i = 0; i < 10; i++) {
840+
for (size_t j = 2; j < 10; j++) {
841+
// Check that the scales/offsets are unchanged.
842+
EXPECT_EQ(TH.at({i, j}), i * 10 + j);
843+
}
844+
}
845+
}
846+
847+
/// Check that initializing a fused tensor with Broadcast that the scale and
848+
/// offset are not changed, and broadcast value is set correctly.
849+
TEST(Tensor, initBroadcastFused) {
850+
Tensor T(ElemKind::Int8FusedQTy, {10, 10}, 0.0, 0);
851+
auto TH = T.getHandle<int8_t>();
852+
for (size_t i = 0; i < 10; i++) {
853+
for (size_t j = 0; j < 10; j++) {
854+
TH.at({i, j}) = i * 10 + j;
855+
}
856+
}
857+
PseudoRNG PRNG;
858+
T.init(Tensor::InitKind::Broadcast, 5, PRNG);
859+
for (size_t i = 0; i < 10; i++) {
860+
for (size_t j = 0; j < 10; j++) {
861+
// Check that the scales/offsets are unchanged, and that the broadcast
862+
// value is everywhere else.
863+
if (j < 2) {
864+
EXPECT_EQ(TH.at({i, j}), 5);
865+
} else {
866+
EXPECT_EQ(TH.at({i, j}), i * 10 + j);
867+
}
868+
}
869+
}
870+
}

0 commit comments

Comments
 (0)