@@ -360,12 +360,12 @@ MlasDequantizeBlockwise(
360
360
);
361
361
362
362
/* *
363
- * @brief Blockwise 2 bits or 4 bits quantization. After quantization, the weights and zero points
364
- * are packed row-wise. In terms of the qbits type, dst and src have the same shape, and
365
- * scales and zero_points have the same shape .
366
- * columns must be multiple of 8 / qbits .
363
+ * @brief Blockwise 4 bits quantization. After quantization, the weights and zero points
364
+ * are packed row-wise. If zero_points is null, quantized type is int4 with default
365
+ * zero point 0, to align with DQ schema. Otherwise, quantized type is uint4 .
366
+ * In int4/uint4, dst have the same shape as src, and zero_points have the same shape as scales .
367
367
* @tparam Tin
368
- * @tparam qbits number of bits used for quantization, 2 or 4
368
+ * @tparam qbits number of bits used for quantization, only 4 is supported
369
369
* @param src points to the floating point matrix, to be quantized, row major shape [rows, columns]
370
370
* @param scales points to the scales matrix, row major
371
371
* @param zero_points points to the zero_points matrix, row major
@@ -376,9 +376,10 @@ MlasDequantizeBlockwise(
376
376
* @param columns
377
377
* @param quant_block_size number of elements in a quantize block
378
378
* @param thread_pool
379
+ * @return the quantized type is signed.
379
380
*/
380
381
template <typename Tin, int qbits>
381
- void
382
+ bool
382
383
MlasQDQQuantizeBlockwise (
383
384
const Tin* src,
384
385
Tin* scales,
@@ -395,8 +396,17 @@ MlasQDQQuantizeBlockwise(
395
396
* @brief Transpose blockwise quantized tensors. The src tensors are row major. src weights and zero
396
397
* points are packed row-wise. The dst tensors are column major. dst weights and zero points
397
398
* are packed column-wise.
399
+ * dst_weights and dst_zero_points are in uint4.
400
+ * If src_weights is int4 and has src_zero_points, src_weights and src_zero_points are
401
+ * converted to uint4 by adding 8.
402
+ * If src_weights is int4 and no src_zero_points, src_weights is converted to uint4 by adding 8.
403
+ * src_zero_points is 0 and dst_zero_points is 8.
404
+ * If src_weights is uint4 and has src_zero_points, just transpose.
405
+ * If src_weights is uint4 and no src_zero_points, caller must allocate dst_zero_points with
406
+ * 0 values. Otherwise exception is thrown.
398
407
* @tparam Tin
399
- * @tparam qbits number of bits used for quantization, 2 or 4
408
+ * @tparam qbits number of bits used for quantization, only 4 is supported
409
+ * @tparam signed_quant true when quantized type is signed, false when quantized type is unsigned
400
410
* @param src_weights points to the quantized matrix, row major, shape [rows, columns] in qbits type.
401
411
* In uint8_t type, shape is [rows, columns * qbits / 8].
402
412
* @param src_scales points to the scales matrix, row major
@@ -410,7 +420,7 @@ MlasQDQQuantizeBlockwise(
410
420
* @param quant_block_size number of elements in a quantize block
411
421
* @param thread_pool
412
422
*/
413
- template <typename Tin, int qbits>
423
+ template <typename Tin, int qbits, bool signed_quant >
414
424
void
415
425
MlasQDQTransposeBlockwiseQuantized (
416
426
const uint8_t * src_weights,
0 commit comments