Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/common/transformer_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ struct DecoderContext {
float attFactor;
float epsilon;

// quantization configuration
int groupsize;

// rope scaling parameters
RopeParams *ropeParamsPtr;

Expand Down Expand Up @@ -132,7 +135,7 @@ struct DecoderContext {
DecoderContext(int _layers, int _hiddenSize, int _headSize, int _attHeadNum, int _kvHeadNum, int _imSize, const std::string &act,
float epsilon, int _vocabSize, int _embeddingSize, int _maxPositions, int _maxPosEmbed, int _maxSeqLength,
int _splitIdx, int _splits, MMHelper *mmHelper, void *device = nullptr, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr,
bool _useLogN = true, bool _useNTK = true, int numThreads = 0)
bool _useLogN = true, bool _useNTK = true, int numThreads = 0, int _groupsize = -1)
: layers(_layers)
, hiddenSize(_hiddenSize)
, attHeadSize(_headSize)
Expand All @@ -153,7 +156,8 @@ struct DecoderContext {
, ppRank(_ppRank)
, tpSize(_splits)
, tpRank(_splitIdx)
, epsilon(epsilon) {
, epsilon(epsilon)
, groupsize(_groupsize) {
if (attHeadNum != 0) {
this->attFactor = 1 / sqrtf(attHeadSize);
}
Expand Down Expand Up @@ -325,4 +329,4 @@ struct DecoderContext {
if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device);
#endif
}
};
};
56 changes: 30 additions & 26 deletions src/layers/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,23 +123,27 @@ class Attention {
float *concatScale = nullptr;
float *concatZero = nullptr;
if constexpr (std::is_same_v<OriWeiT, int8_t> || std::is_same_v<OriWeiT, uint4x2_t>) {
concatScale = (float *)malloc(responsibleCols * sizeof(float));
concatZero = (float *)malloc(responsibleCols * sizeof(float));
memcpy(concatScale, queryScale + this->startQHead * headSize, qResponsibleCols * sizeof(float));
memcpy(concatScale + qResponsibleCols, keyScale + this->startKVHead * headSize,
int qkvStride = (ctx->attHeadNum + ctx->kvHeadNum + ctx->kvHeadNum) * ctx->attHeadSize;
int groups = ctx->groupsize == -1 ? 1 : hiddenSize / ctx->groupsize;
concatScale = (float *)malloc(groups * responsibleCols * sizeof(float));
concatZero = (float *)malloc(groups * responsibleCols * sizeof(float));
for (int i = 0; i < groups; ++i) {
memcpy(concatScale + i * responsibleCols, queryScale + i * qkvStride + this->startQHead * headSize, qResponsibleCols * sizeof(float));
memcpy(concatScale + i * responsibleCols + qResponsibleCols, keyScale + i * qkvStride + this->startKVHead * headSize,
kvResponsibleCols * sizeof(float));
memcpy(concatScale + qResponsibleCols + kvResponsibleCols, valueScale + this->startKVHead * headSize,
memcpy(concatScale + i * responsibleCols + qResponsibleCols + kvResponsibleCols, valueScale + i * qkvStride + this->startKVHead * headSize,
kvResponsibleCols * sizeof(float));
memcpy(concatZero, queryZero + this->startQHead * headSize, qResponsibleCols * sizeof(float));
memcpy(concatZero + qResponsibleCols, keyZero + this->startKVHead * headSize,
memcpy(concatZero + i * responsibleCols, queryZero + i * qkvStride + this->startQHead * headSize, qResponsibleCols * sizeof(float));
memcpy(concatZero + i * responsibleCols + qResponsibleCols, keyZero + i * qkvStride + this->startKVHead * headSize,
kvResponsibleCols * sizeof(float));
memcpy(concatZero + qResponsibleCols + kvResponsibleCols, valueZero + this->startKVHead * headSize,
memcpy(concatZero + i * responsibleCols + qResponsibleCols + kvResponsibleCols, valueZero + i * qkvStride + this->startKVHead * headSize,
kvResponsibleCols * sizeof(float));
}
}

xft::Matrix<WeiT> convertedqkvWeight;
ctx->mmHelper->convertWeight(trans, hiddenSize, responsibleCols, concatBuf, concatScale, concatZero,
convertedqkvWeight, qkvWeightScale, qkvWeightZero, qkvWeightSum);
convertedqkvWeight, qkvWeightScale, qkvWeightZero, qkvWeightSum, ctx->groupsize);

#ifdef XFT_GPU
xft::Matrix<WeiT> qkvWeightT;
Expand Down Expand Up @@ -182,7 +186,7 @@ class Attention {
xft::Matrix<WeiT> convertedOutWeight;
ctx->mmHelper->convertWeight(trans, ctx->attHeadNum * ctx->attHeadSize, hiddenSize, attnOutWeight, attnOutScale,
attnOutZero, this->startQHead * headSize, qResponsibleCols, false, convertedOutWeight,
attnOutputWeightScale, attnOutputWeightZero, attnOutputWeightSum, true);
attnOutputWeightScale, attnOutputWeightZero, attnOutputWeightSum, ctx->groupsize, true);

#ifdef XFT_GPU
xft::Matrix<WeiT> outWeightT;
Expand Down Expand Up @@ -289,11 +293,11 @@ class Attention {
if (qkvBias.Size() == 0) {
ctx->mmHelper->compute(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(),
imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(),
qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride());
qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), ctx->groupsize);
} else {
ctx->mmHelper->compute_bias(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f,
imBuffer.Data(), imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(),
qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data());
qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data(), ctx->groupsize);
}
t2.release();

Expand Down Expand Up @@ -405,26 +409,26 @@ class Attention {
ctx->mmHelper->compute_residential(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(),
1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(),
attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f,
outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride());
outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride(), ctx->groupsize);
} else {
float *pbias = attnOutputBias.Data();
if (attnOutputBias.Size() == 0) { pbias = nullptr; }
ctx->mmHelper->compute_resext(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride());
outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride(), ctx->groupsize);
}
} else {
if (attnOutputBias.Size() == 0) {
ctx->mmHelper->compute(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
outBuffer.Stride());
outBuffer.Stride(), ctx->groupsize);
} else {
ctx->mmHelper->compute_bias(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
outBuffer.Stride(), attnOutputBias.Data());
outBuffer.Stride(), attnOutputBias.Data(), ctx->groupsize);
}
}
t5.release();
Expand Down Expand Up @@ -495,11 +499,11 @@ class Attention {
if (qkvBias.Size() == 0) {
ctx->mmHelper->compute(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(),
imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(),
qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride());
qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), ctx->groupsize);
} else {
ctx->mmHelper->compute_bias(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f,
imBuffer.Data(), imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(),
qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data());
qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data(), ctx->groupsize);
}
t2.release();

Expand Down Expand Up @@ -588,26 +592,26 @@ class Attention {
ctx->mmHelper->compute_residential(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(),
1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(),
attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f,
outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride());
outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride(), ctx->groupsize);
} else {
float *pbias = attnOutputBias.Data();
if (attnOutputBias.Size() == 0) { pbias = nullptr; }
ctx->mmHelper->compute_resext(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride());
outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride(), ctx->groupsize);
}
} else {
if (attnOutputBias.Size() == 0) {
ctx->mmHelper->compute(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
outBuffer.Stride());
outBuffer.Stride(), ctx->groupsize);
} else {
ctx->mmHelper->compute_bias(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
outBuffer.Stride(), attnOutputBias.Data());
outBuffer.Stride(), attnOutputBias.Data(), ctx->groupsize);
}
}
t5.release();
Expand Down Expand Up @@ -1183,15 +1187,15 @@ class Attention {

// query, key, value weights
xft::Matrix<WeiT> qkvWeight;
xft::Vector<float> qkvWeightScale; // if weight is int8
xft::Vector<float> qkvWeightZero; // if weight is int8
xft::Matrix<float> qkvWeightScale; // if weight is int8
xft::Matrix<float> qkvWeightZero; // if weight is int8
xft::Vector<float> qkvWeightSum; // if weight is int8
// query, key, value bias
xft::Vector<float> qkvBias;

xft::Matrix<WeiT> attnOutputWeight;
xft::Vector<float> attnOutputWeightScale; // if weight is int8
xft::Vector<float> attnOutputWeightZero; // if weight is int8
xft::Matrix<float> attnOutputWeightScale; // if weight is int8
xft::Matrix<float> attnOutputWeightZero; // if weight is int8
xft::Vector<float> attnOutputWeightSum; // if weight is int8
xft::Vector<float> attnOutputBias;

Expand Down
Loading