128 changes: 111 additions & 17 deletions include/tvm/runtime/data_type.h
@@ -97,9 +97,14 @@ class DataType {
if (code == kBFloat) {
ICHECK_EQ(bits, 16);
}
if (code == kFloat8_e4m3fn || code == kFloat8_e5m2) {
if (code == kFloat8_e3m4 || code == kFloat8_e4m3 || code == kFloat8_e4m3b11fnuz ||
code == kFloat8_e4m3fn || code == kFloat8_e4m3fnuz || code == kFloat8_e5m2 ||
code == kFloat8_e5m2fnuz || code == kFloat8_e8m0fnu) {
ICHECK_EQ(bits, 8);
}
if (code == kFloat6_e2m3fn || code == kFloat6_e3m2fn) {
ICHECK_EQ(bits, 6);
}
if (code == kFloat4_e2m1fn) {
ICHECK_EQ(bits, 4);
}
@@ -138,17 +143,45 @@ class DataType {
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a bfloat type. */
bool is_bfloat() const { return code() == DataType::kBFloat; }
/*! \return whether type is a float8 type. */
/*! \return whether type is any 8-bit custom Float8 variant. */
bool is_float8() const {
return (code() == DataType::kFloat || code() == DataType::kFloat8_e4m3fn ||
code() == DataType::kFloat8_e5m2) &&
bits() == 8;
return bits() == 8 &&
(code() == DataType::kFloat8_e3m4 || code() == DataType::kFloat8_e4m3 ||
code() == DataType::kFloat8_e4m3b11fnuz || code() == DataType::kFloat8_e4m3fn ||
code() == DataType::kFloat8_e4m3fnuz || code() == DataType::kFloat8_e5m2 ||
code() == DataType::kFloat8_e5m2fnuz || code() == DataType::kFloat8_e8m0fnu);
}
/*! \return whether type is any 6-bit custom Float6 variant. */
bool is_float6() const {
return bits() == 6 &&
(code() == DataType::kFloat6_e2m3fn || code() == DataType::kFloat6_e3m2fn);
}
/*! \return whether type is the 4-bit custom Float4_e2m1fn variant. */
bool is_float4() const { return bits() == 4 && code() == DataType::kFloat4_e2m1fn; }
/*! \return whether type is Float8E3M4. */
bool is_float8_e3m4() const { return bits() == 8 && code() == DataType::kFloat8_e3m4; }
/*! \return whether type is Float8E4M3. */
bool is_float8_e4m3() const { return bits() == 8 && code() == DataType::kFloat8_e4m3; }
/*! \return whether type is Float8E4M3B11FNUZ. */
bool is_float8_e4m3b11fnuz() const {
return bits() == 8 && code() == DataType::kFloat8_e4m3b11fnuz;
}
/*! \return whether type is a float4 type. */
bool is_float4() const { return code() == DataType::kFloat4_e2m1fn && bits() == 4; }
bool is_float8_e4m3fn() const { return (code() == DataType::kFloat8_e4m3fn && bits() == 8); }
bool is_float8_e5m2() const { return (code() == DataType::kFloat8_e5m2 && bits() == 8); }
bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4_e2m1fn && bits() == 4); }
/*! \return whether type is Float8E4M3FN. */
bool is_float8_e4m3fn() const { return bits() == 8 && code() == DataType::kFloat8_e4m3fn; }
/*! \return whether type is Float8E4M3FNUZ. */
bool is_float8_e4m3fnuz() const { return bits() == 8 && code() == DataType::kFloat8_e4m3fnuz; }
/*! \return whether type is Float8E5M2. */
bool is_float8_e5m2() const { return bits() == 8 && code() == DataType::kFloat8_e5m2; }
/*! \return whether type is Float8E5M2FNUZ. */
bool is_float8_e5m2fnuz() const { return bits() == 8 && code() == DataType::kFloat8_e5m2fnuz; }
/*! \return whether type is Float8E8M0FNU. */
bool is_float8_e8m0fnu() const { return bits() == 8 && code() == DataType::kFloat8_e8m0fnu; }
/*! \return whether type is Float6E2M3FN. */
bool is_float6_e2m3fn() const { return bits() == 6 && code() == DataType::kFloat6_e2m3fn; }
/*! \return whether type is Float6E3M2FN. */
bool is_float6_e3m2fn() const { return bits() == 6 && code() == DataType::kFloat6_e3m2fn; }
/*! \return whether type is Float4E2M1FN. */
bool is_float4_e2m1fn() const { return bits() == 4 && code() == DataType::kFloat4_e2m1fn; }
/*! \return whether type is a float16 type. */
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
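The variant-specific predicates above sit alongside the broadened is_float8()/is_float6()/is_float4() checks. A minimal usage sketch, not part of the diff, assuming DataType lives in tvm::runtime as the header path suggests and using the factory helpers defined further down in this header:

#include <tvm/runtime/data_type.h>

using tvm::runtime::DataType;

void CheckCustomFloatPredicates() {
  DataType e5m2 = DataType::Float8E5M2();
  ICHECK(e5m2.is_float8());          // true for every 8-bit Float8 variant
  ICHECK(e5m2.is_float8_e5m2());     // variant-specific predicate
  ICHECK(!e5m2.is_float8_e4m3fn());  // other variants do not match

  DataType e2m3 = DataType::Float6E2M3FN();
  ICHECK(e2m3.is_float6() && e2m3.is_float6_e2m3fn());

  DataType e2m1 = DataType::Float4E2M1FN();
  ICHECK(e2m1.is_float4() && e2m1.is_float4_e2m1fn());
}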
@@ -262,23 +295,83 @@ class DataType {
*/
static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
/*!
* \brief Construct NV float8 e4m3 datatype.
* \brief Construct float8 e3m4 datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType Float8E3M4(int lanes = 1) { return DataType(kFloat8_e3m4, 8, lanes); }

/*!
* \brief Construct float8 e4m3 datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType Float8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3, 8, lanes); }

/*!
* \brief Construct float8 e4m3b11fnuz datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType Float8E4M3B11FNUZ(int lanes = 1) {
return DataType(kFloat8_e4m3b11fnuz, 8, lanes);
}

/*!
* \brief Construct float8 e4m3fn datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType Float8E4M3FN(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); }

/*!
* \brief Construct float8 e4m3fnuz datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType Float8E4M3FNUZ(int lanes = 1) { return DataType(kFloat8_e4m3fnuz, 8, lanes); }

/*!
* \brief Construct float8 e5m2 datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); }
static DataType Float8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); }

/*!
* \brief Construct float8 e5m2fnuz datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType Float8E5M2FNUZ(int lanes = 1) { return DataType(kFloat8_e5m2fnuz, 8, lanes); }

/*!
* \brief Construct NV float8 e5m2 datatype.
* \brief Construct float8 e8m0fnu datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); }
static DataType Float8E8M0FNU(int lanes = 1) { return DataType(kFloat8_e8m0fnu, 8, lanes); }

/*!
* \brief Construct float6 e2m3fn datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType Float6E2M3FN(int lanes = 1) { return DataType(kFloat6_e2m3fn, 6, lanes); }

/*!
* \brief Construct float6 e3m2fn datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType Float6E3M2FN(int lanes = 1) { return DataType(kFloat6_e3m2fn, 6, lanes); }

/*!
* \brief Construct NV float4_e2m1fn datatype.
* \brief Construct float4 e2m1fn datatype.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); }
static DataType Float4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); }
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes.
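The former NVFloat8E4M3 / NVFloat8E5M2 / NVFloat4E2M1FN helpers are replaced by NV-prefix-free factories covering every variant. A hedged construction sketch, not from the diff, again assuming the tvm::runtime namespace:

using tvm::runtime::DataType;

DataType a = DataType::Float8E4M3FN();    // scalar: a.bits() == 8, a.lanes() == 1
DataType b = DataType::Float8E8M0FNU(4);  // 4-lane vector: b.lanes() == 4
DataType c = DataType::Float6E3M2FN();    // c.bits() == 6
DataType d = DataType::Float4E2M1FN(2);   // d.bits() == 4, d.lanes() == 2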
@@ -325,7 +418,8 @@ inline int GetVectorBytes(DataType dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1FN()) {
dtype == DataType::Int(1) || dtype == DataType::Float4E2M1FN() ||
dtype == DataType::Float6E2M3FN() || dtype == DataType::Float6E3M2FN()) {
return 1;
}
ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
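GetVectorBytes now special-cases the scalar float4 and float6 types alongside bool and int4, since a single element does not fill a byte. A hedged sketch of the resulting behavior, not from the diff, assuming GetVectorBytes sits in tvm::runtime next to DataType:

using tvm::runtime::DataType;
using tvm::runtime::GetVectorBytes;

int a = GetVectorBytes(DataType::Float4E2M1FN());   // 1: special-cased scalar
int b = GetVectorBytes(DataType::Float6E3M2FN());   // 1: special-cased scalar
int c = GetVectorBytes(DataType::Float8E5M2(4));    // 4: 8 bits x 4 lanes = 32 bits
int d = GetVectorBytes(DataType::Float4E2M1FN(4));  // 2: 4 bits x 4 lanes = 16 bits
// A 3-lane float4 vector (12 bits) would trip the byte-alignment ICHECK above.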
17 changes: 13 additions & 4 deletions include/tvm/script/ir_builder/tir/ir.h
@@ -504,10 +504,19 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64));

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::NVFloat8E4M3);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, DataType::NVFloat8E5M2);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::NVFloat4E2M1FN);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E3M4, DataType::Float8E3M4);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3, DataType::Float8E4M3);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3B11FNUZ, DataType::Float8E4M3B11FNUZ);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::Float8E4M3FN);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FNUZ, DataType::Float8E4M3FNUZ);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, DataType::Float8E5M2);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2FNUZ, DataType::Float8E5M2FNUZ);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E8M0FNU, DataType::Float8E8M0FNU);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E2M3FN, DataType::Float6E2M3FN);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E3M2FN, DataType::Float6E3M2FN);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::Float4E2M1FN);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
2 changes: 1 addition & 1 deletion include/tvm/tir/op.h
@@ -950,7 +950,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
}
}
if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float4())
if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float6() || t.is_float4())
return FloatImm(t, static_cast<double>(value), span);
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
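With is_float6() added to the check, constants of the new float6 dtypes take the same FloatImm path as float8 and float4 constants. A hedged sketch, not from the diff, assuming MakeConstScalar is reachable as tvm::tir::MakeConstScalar like the rest of op.h:

using tvm::runtime::DataType;

tvm::PrimExpr c6 = tvm::tir::MakeConstScalar(DataType::Float6E2M3FN(), 0.75);
tvm::PrimExpr c8 = tvm::tir::MakeConstScalar(DataType::Float8E5M2(), 1.5);
// Both results are FloatImm nodes carrying the requested dtype.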
70 changes: 30 additions & 40 deletions python/tvm/runtime/ndarray.py
@@ -149,10 +149,9 @@ def copyfrom(self, source_array):
source_array = np.ascontiguousarray(
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
)
if self.dtype.startswith("float4_e2m1fn") and self.dtype != "float4_e2m1fn":
# float4_e2m1fn in numpy is not packed.
# So we need to pack the input data when converting to vectorized float4_e2m1fn type.
data_bits = source_array.view(dtype="uint8")
if self.dtype.startswith("float4_e2m1fn"):
# we need to pack the input data when converting to the float4_e2m1fn type.
data_bits = source_array.view(dtype="uint8").flatten()
if data_bits.size % 2:
data_bits = np.pad(data_bits, (0, 1), mode="constant", constant_values=0)
data_bits = data_bits.reshape(-1, 2)
@@ -189,54 +188,45 @@ def numpy(self):
dtype = str(t)
if dtype == "int4":
dtype = "int8"
if dtype == "bfloat16":
if ml_dtypes is not None:
dtype = ml_dtypes.bfloat16
else:
if dtype in [
"bfloat16",
"float8_e3m4",
"float8_e4m3",
"float8_e4m3b11fnuz",
"float8_e4m3fn",
"float8_e4m3fnuz",
"float8_e5m2",
"float8_e5m2fnuz",
"float8_e8m0fnu",
"float6_e2m3fn",
"float6_e3m2fn",
"float4_e2m1fn",
]:
if ml_dtypes is None:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert bfloat16 array to numpy."
)
if dtype == "float8_e4m3fn":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e4m3fn
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert float8_e4m3fn array to numpy."
)
if dtype == "float8_e5m2":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e5m2
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert float8_e5m2 array to numpy."
)
if dtype == "float4_e2m1fn":
if ml_dtypes is not None:
dtype = ml_dtypes.float4_e2m1fn
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert float4_e2m1fn array to numpy."
f"ml_dtypes is not installed, cannot convert {dtype} array to numpy."
)
try:
dtype = getattr(ml_dtypes, dtype)
except AttributeError:
raise RuntimeError(f"ml_dtypes has no attribute '{dtype}', cannot convert array.")
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p)
if old_dtype.startswith("float4_e2m1fn") and old_dtype != "float4_e2m1fn":
nbytes = np_arr.size * np_arr.dtype.itemsize // 2
else:
nbytes = np_arr.size * np_arr.dtype.itemsize
# TODO(kathy): revisit and get a mirrored function of ffi::GetDataSize
# in Python to replace line below
nbytes = np_arr.size if dtype == "bool" else (np_arr.size * old_dtype.bits + 7) // 8
_ffi_api.TVMArrayCopyToBytes(self, data, nbytes)

if old_dtype == "int4" or (
old_dtype.startswith("float4_e2m1fn") and old_dtype != "float4_e2m1fn"
):
if old_dtype == "int4" or old_dtype.startswith("float4_e2m1fn"):
length = np_arr.size
np_arr = np_arr.view("int8")
np_arr_ret = np.empty((length,), dtype="int8")
np_arr = np_arr.reshape((length,))
old_index = np.bitwise_and(np_arr, 0x0F)
odd_index = np.bitwise_and(np_arr, 0x0F)
even_index = np.bitwise_and(np_arr >> 4, 0x0F)
np_arr_ret[1::2] = old_index[0 : length // 2]
np_arr_ret[0::2] = even_index[0 : length // 2]
np_arr_ret[1::2] = odd_index[0 : length // 2]
np_arr_ret[0::2] = even_index[0 : (length + 1) // 2]
return np_arr_ret.reshape(shape).view(dtype)

return np_arr
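Both NDArray paths above treat float4_e2m1fn data as packed, two values per byte. A standalone C++ sketch of the nibble convention implied by the unpacking code in numpy() (element 2i in the high nibble, element 2i+1 in the low nibble); the helper names are illustrative only, not TVM API:

#include <cstddef>
#include <cstdint>
#include <vector>

// Pack one 4-bit value per input byte into two values per output byte:
// even-indexed elements go to the high nibble, odd-indexed ones to the low nibble.
std::vector<uint8_t> PackNibbles(const std::vector<uint8_t>& vals) {
  std::vector<uint8_t> out((vals.size() + 1) / 2, 0);
  for (size_t i = 0; i < vals.size(); ++i) {
    uint8_t nib = vals[i] & 0x0F;
    out[i / 2] |= (i % 2 == 0) ? static_cast<uint8_t>(nib << 4) : nib;
  }
  return out;
}

// Unpack back to one value per byte, mirroring the even/odd index split in numpy().
std::vector<uint8_t> UnpackNibbles(const std::vector<uint8_t>& packed, size_t n) {
  std::vector<uint8_t> out(n, 0);
  for (size_t i = 0; i < n; ++i) {
    uint8_t byte = packed[i / 2];
    out[i] = (i % 2 == 0) ? (byte >> 4) & 0x0F : byte & 0x0F;
  }
  return out;
}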