From 1c36432ef5d6e57985b949195f2c024396bcd07c Mon Sep 17 00:00:00 2001 From: Kathryn-cat Date: Fri, 30 May 2025 16:26:16 -0400 Subject: [PATCH 1/4] addressed comments Co-authored-by: DerrickYLJ --- include/tvm/runtime/data_type.h | 120 +++++++++++++-- include/tvm/script/ir_builder/tir/ir.h | 12 +- include/tvm/tir/op.h | 2 +- python/tvm/runtime/ndarray.py | 68 ++++----- python/tvm/script/ir_builder/tir/ir.py | 139 +++++++++++++----- src/ir/expr.cc | 75 ++++++++-- src/runtime/device_api.cc | 3 +- src/runtime/ndarray.cc | 4 + src/script/ir_builder/tir/ir.cc | 34 ++++- src/support/scalars.h | 30 +++- src/target/llvm/codegen_llvm.cc | 9 +- src/target/source/codegen_cuda.cc | 97 +++++++++++- src/target/source/codegen_cuda.h | 6 +- src/target/source/literal/cuda_half_t.h | 46 ++++++ src/tir/op/op.cc | 41 ++++++ src/tir/transforms/dtype_conversion.cc | 3 +- src/tir/transforms/dtype_conversion.h | 38 ++++- .../codegen/test_target_codegen_cuda_fp4.py | 130 +--------------- .../codegen/test_target_codegen_cuda_fp8.py | 135 +++++++++-------- tests/python/ffi/test_dtype.py | 31 +++- tests/python/ir/test_datatype_nv_fp4.py | 52 +++++++ tests/python/ir/test_datatype_nv_fp8.py | 31 +++- tests/python/ir/test_dtype.py | 9 +- .../tvmscript/test_tvmscript_printer_tir.py | 43 +++--- 24 files changed, 818 insertions(+), 340 deletions(-) create mode 100644 tests/python/ir/test_datatype_nv_fp4.py diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index d5f3c6ee3d7f..9e9bcddc2957 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -97,9 +97,14 @@ class DataType { if (code == kBFloat) { ICHECK_EQ(bits, 16); } - if (code == kFloat8_e4m3fn || code == kFloat8_e5m2) { + if (code == kFloat8_e3m4 || code == kFloat8_e4m3 || code == kFloat8_e4m3b11fnuz || + code == kFloat8_e4m3fn || code == kFloat8_e4m3fnuz || code == kFloat8_e5m2 || + code == kFloat8_e5m2fnuz || code == kFloat8_e8m0fnu) { ICHECK_EQ(bits, 8); } + if (code == kFloat6_e2m3fn || code == kFloat6_e3m2fn) { + ICHECK_EQ(bits, 6); + } if (code == kFloat4_e2m1fn) { ICHECK_EQ(bits, 4); } @@ -138,17 +143,45 @@ class DataType { bool is_float() const { return code() == DataType::kFloat; } /*! \return whether type is a bfloat type. */ bool is_bfloat() const { return code() == DataType::kBFloat; } - /*! \return whether type is a float8 type. */ + /*! \return whether type is any 8-bit custom Float8 variant. */ bool is_float8() const { - return (code() == DataType::kFloat || code() == DataType::kFloat8_e4m3fn || - code() == DataType::kFloat8_e5m2) && - bits() == 8; + return bits() == 8 && + (code() == DataType::kFloat8_e3m4 || code() == DataType::kFloat8_e4m3 || + code() == DataType::kFloat8_e4m3b11fnuz || code() == DataType::kFloat8_e4m3fn || + code() == DataType::kFloat8_e4m3fnuz || code() == DataType::kFloat8_e5m2 || + code() == DataType::kFloat8_e5m2fnuz || code() == DataType::kFloat8_e8m0fnu); + } + /*! \return whether type is any 6-bit custom Float6 variant. */ + bool is_float6() const { + return bits() == 6 && + (code() == DataType::kFloat6_e2m3fn || code() == DataType::kFloat6_e3m2fn); + } + /*! \return whether type is the 4-bit custom Float4_e2m1fn variant. */ + bool is_float4() const { return bits() == 4 && code() == DataType::kFloat4_e2m1fn; } + /*! \return whether type is Float8E3M4. */ + bool is_float8_e3m4() const { return bits() == 8 && code() == DataType::kFloat8_e3m4; } + /*! \return whether type is Float8E4M3. 
*/ + bool is_float8_e4m3() const { return bits() == 8 && code() == DataType::kFloat8_e4m3; } + /*! \return whether type is Float8E4M3B11FNUZ. */ + bool is_float8_e4m3b11fnuz() const { + return bits() == 8 && code() == DataType::kFloat8_e4m3b11fnuz; } - /*! \return whether type is a float4 type. */ - bool is_float4() const { return code() == DataType::kFloat4_e2m1fn && bits() == 4; } - bool is_float8_e4m3fn() const { return (code() == DataType::kFloat8_e4m3fn && bits() == 8); } - bool is_float8_e5m2() const { return (code() == DataType::kFloat8_e5m2 && bits() == 8); } - bool is_float4_e2m1fn() const { return (code() == DataType::kFloat4_e2m1fn && bits() == 4); } + /*! \return whether type is Float8E4M3FN. */ + bool is_float8_e4m3fn() const { return bits() == 8 && code() == DataType::kFloat8_e4m3fn; } + /*! \return whether type is Float8E4M3FNUZ. */ + bool is_float8_e4m3fnuz() const { return bits() == 8 && code() == DataType::kFloat8_e4m3fnuz; } + /*! \return whether type is Float8E5M2. */ + bool is_float8_e5m2() const { return bits() == 8 && code() == DataType::kFloat8_e5m2; } + /*! \return whether type is Float8E5M2FNUZ. */ + bool is_float8_e5m2fnuz() const { return bits() == 8 && code() == DataType::kFloat8_e5m2fnuz; } + /*! \return whether type is Float8E8M0FNU. */ + bool is_float8_e8m0fnu() const { return bits() == 8 && code() == DataType::kFloat8_e8m0fnu; } + /*! \return whether type is Float6E2M3FN. */ + bool is_float6_e2m3fn() const { return bits() == 6 && code() == DataType::kFloat6_e2m3fn; } + /*! \return whether type is Float6E3M2FN. */ + bool is_float6_e3m2fn() const { return bits() == 6 && code() == DataType::kFloat6_e3m2fn; } + /*! \return whether type is Float4E2M1FN. */ + bool is_float4_e2m1fn() const { return bits() == 4 && code() == DataType::kFloat4_e2m1fn; } /*! \return whether type is a float16 type. */ bool is_float16() const { return is_float() && bits() == 16; } /*! \return whether type is a bfloat16 type. */ @@ -261,20 +294,80 @@ class DataType { * \return The constructed data type. */ static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); } + /*! + * \brief Construct NV float8 e3m4 datatype. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType NVFloat8E3M4(int lanes = 1) { return DataType(kFloat8_e3m4, 8, lanes); } + /*! * \brief Construct NV float8 e4m3 datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); } + static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3, 8, lanes); } + + /*! + * \brief Construct NV float8 e4m3b11fnuz datatype. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType NVFloat8E4M3B11FNUZ(int lanes = 1) { + return DataType(kFloat8_e4m3b11fnuz, 8, lanes); + } + + /*! + * \brief Construct NV float8 e4m3fn datatype. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType NVFloat8E4M3FN(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); } + + /*! + * \brief Construct NV float8 e4m3fnuz datatype. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType NVFloat8E4M3FNUZ(int lanes = 1) { return DataType(kFloat8_e4m3fnuz, 8, lanes); } + /*! * \brief Construct NV float8 e5m2 datatype. * \param lanes The number of lanes * \return The constructed data type. 
*/ static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); } + + /*! + * \brief Construct NV float8 e5m2fnuz datatype. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType NVFloat8E5M2FNUZ(int lanes = 1) { return DataType(kFloat8_e5m2fnuz, 8, lanes); } + + /*! + * \brief Construct NV float8 e8m0fnu datatype. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType NVFloat8E8M0FNU(int lanes = 1) { return DataType(kFloat8_e8m0fnu, 8, lanes); } + + /*! + * \brief Construct NV float6 e2m3fn datatype. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType NVFloat6E2M3FN(int lanes = 1) { return DataType(kFloat6_e2m3fn, 6, lanes); } + + /*! + * \brief Construct NV float6 e3m2fn datatype. + * \param lanes The number of lanes + * \return The constructed data type. + */ + static DataType NVFloat6E3M2FN(int lanes = 1) { return DataType(kFloat6_e3m2fn, 6, lanes); } + /*! - * \brief Construct NV float4_e2m1fn datatype. + * \brief Construct NV float4 e2m1fn datatype. * \param lanes The number of lanes * \return The constructed data type. */ @@ -325,7 +418,8 @@ inline int GetVectorBytes(DataType dtype) { int data_bits = dtype.bits() * dtype.lanes(); // allow bool to exist if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || - dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1FN()) { + dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1FN() || + dtype == DataType::NVFloat6E2M3FN() || dtype == DataType::NVFloat6E3M2FN()) { return 1; } ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index d3eb8ac435d5..cfbaac7dc2a0 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -504,8 +504,18 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::NVFloat8E4M3); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E3M4, DataType::NVFloat8E3M4); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3, DataType::NVFloat8E4M3); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3B11FNUZ, + DataType::NVFloat8E4M3B11FNUZ); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::NVFloat8E4M3FN); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FNUZ, DataType::NVFloat8E4M3FNUZ); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, DataType::NVFloat8E5M2); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2FNUZ, DataType::NVFloat8E5M2FNUZ); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E8M0FNU, DataType::NVFloat8E8M0FNU); + +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E2M3FN, DataType::NVFloat6E2M3FN); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E3M2FN, DataType::NVFloat6E3M2FN); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::NVFloat4E2M1FN); diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index ce7a425c94f9..99139f83b297 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -950,7 +950,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) return 
LargeUIntImm(t, static_cast(low), static_cast(high), span); } } - if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float4()) + if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float6() || t.is_float4()) return FloatImm(t, static_cast(value), span); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 9a026707cb48..78af2569d2e0 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -149,10 +149,9 @@ def copyfrom(self, source_array): source_array = np.ascontiguousarray( source_array, dtype="uint16" if dtype == "bfloat16" else dtype ) - if self.dtype.startswith("float4_e2m1fn") and self.dtype != "float4_e2m1fn": - # float4_e2m1fn in numpy is not packed. - # So we need to pack the input data when converting to vectorized float4_e2m1fn type. - data_bits = source_array.view(dtype="uint8") + if self.dtype.startswith("float4_e2m1fn"): + # we need to pack the input data when converting to float4_e2m1fn type, + data_bits = source_array.view(dtype="uint8").flatten() if data_bits.size % 2: data_bits = np.pad(data_bits, (0, 1), mode="constant", constant_values=0) data_bits = data_bits.reshape(-1, 2) @@ -189,54 +188,43 @@ def numpy(self): dtype = str(t) if dtype == "int4": dtype = "int8" - if dtype == "bfloat16": - if ml_dtypes is not None: - dtype = ml_dtypes.bfloat16 - else: + if dtype in [ + "bfloat16", + "float8_e3m4", + "float8_e4m3", + "float8_e4m3b11fnuz", + "float8_e4m3fn", + "float8_e4m3fnuz", + "float8_e5m2", + "float8_e5m2fnuz", + "float8_e8m0fnu", + "float6_e2m3fn", + "float6_e3m2fn", + "float4_e2m1fn", + ]: + if ml_dtypes is None: raise RuntimeError( - "ml_dtypes is not installed, cannot convert bfloat16 array to numpy." - ) - if dtype == "float8_e4m3fn": - if ml_dtypes is not None: - dtype = ml_dtypes.float8_e4m3fn - else: - raise RuntimeError( - "ml_dtypes is not installed, cannot convert float8_e4m3fn array to numpy." - ) - if dtype == "float8_e5m2": - if ml_dtypes is not None: - dtype = ml_dtypes.float8_e5m2 - else: - raise RuntimeError( - "ml_dtypes is not installed, cannot convert float8_e5m2 array to numpy." - ) - if dtype == "float4_e2m1fn": - if ml_dtypes is not None: - dtype = ml_dtypes.float4_e2m1fn - else: - raise RuntimeError( - "ml_dtypes is not installed, cannot convert float4_e2m1fn array to numpy." + f"ml_dtypes is not installed, cannot convert {dtype} array to numpy." 
) + try: + dtype = getattr(ml_dtypes, dtype) + except AttributeError: + raise RuntimeError(f"ml_dtypes has no attribute '{dtype}', cannot convert array.") np_arr = np.empty(shape, dtype=dtype) assert np_arr.flags["C_CONTIGUOUS"] data = np_arr.ctypes.data_as(ctypes.c_void_p) - if old_dtype.startswith("float4_e2m1fn") and old_dtype != "float4_e2m1fn": - nbytes = np_arr.size * np_arr.dtype.itemsize // 2 - else: - nbytes = np_arr.size * np_arr.dtype.itemsize + nbytes = (np_arr.size * old_dtype.bits + 7) // 8 _ffi_api.TVMArrayCopyToBytes(self, data, nbytes) - if old_dtype == "int4" or ( - old_dtype.startswith("float4_e2m1fn") and old_dtype != "float4_e2m1fn" - ): + if old_dtype == "int4" or old_dtype.startswith("float4_e2m1fn"): length = np_arr.size np_arr = np_arr.view("int8") np_arr_ret = np.empty((length,), dtype="int8") np_arr = np_arr.reshape((length,)) - old_index = np.bitwise_and(np_arr, 0x0F) + odd_index = np.bitwise_and(np_arr, 0x0F) even_index = np.bitwise_and(np_arr >> 4, 0x0F) - np_arr_ret[1::2] = old_index[0 : length // 2] - np_arr_ret[0::2] = even_index[0 : length // 2] + np_arr_ret[1::2] = odd_index[0 : length // 2] + np_arr_ret[0::2] = even_index[0 : (length + 1) // 2] return np_arr_ret.reshape(shape).view(dtype) return np_arr diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index e270b9152643..1aaeaa034724 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1300,7 +1300,9 @@ def buffer_store( for index in indices: if isinstance(index, slice): step = 1 if index.step is None else index.step - lanes = Analyzer().simplify((index.stop - index.start + step - 1) // step) + lanes = Analyzer().simplify( # pylint: disable=redefined-outer-name + (index.stop - index.start + step - 1) // step + ) if lanes == 1: expr_indices.append(index.start) else: @@ -1426,6 +1428,9 @@ def func( float16 = func_gen(("Float16")) float32 = func_gen(("Float32")) float64 = func_gen(("Float64")) +float16x2 = func_gen(("Float16x2")) +float32x2 = func_gen(("Float32x2")) +float64x2 = func_gen(("Float64x2")) float16x4 = func_gen(("Float16x4")) float32x4 = func_gen(("Float32x4")) float64x4 = func_gen(("Float64x4")) @@ -1442,20 +1447,89 @@ def func( float32x64 = func_gen(("Float32x64")) float64x64 = func_gen(("Float64x64")) +# Float8 variants +float8_e3m4 = func_gen(("Float8E3M4")) +float8_e3m4x2 = func_gen(("Float8E3M4x2")) +float8_e3m4x4 = func_gen(("Float8E3M4x4")) +float8_e3m4x8 = func_gen(("Float8E3M4x8")) +float8_e3m4x16 = func_gen(("Float8E3M4x16")) +float8_e3m4x32 = func_gen(("Float8E3M4x32")) +float8_e3m4x64 = func_gen(("Float8E3M4x64")) + +float8_e4m3 = func_gen(("Float8E4M3")) +float8_e4m3x2 = func_gen(("Float8E4M3x2")) +float8_e4m3x4 = func_gen(("Float8E4M3x4")) +float8_e4m3x8 = func_gen(("Float8E4M3x8")) +float8_e4m3x16 = func_gen(("Float8E4M3x16")) +float8_e4m3x32 = func_gen(("Float8E4M3x32")) +float8_e4m3x64 = func_gen(("Float8E4M3x64")) + +float8_e4m3b11fnuz = func_gen(("Float8E4M3B11FNUZ")) +float8_e4m3b11fnuzx2 = func_gen(("Float8E4M3B11FNUZx2")) +float8_e4m3b11fnuzx4 = func_gen(("Float8E4M3B11FNUZx4")) +float8_e4m3b11fnuzx8 = func_gen(("Float8E4M3B11FNUZx8")) +float8_e4m3b11fnuzx16 = func_gen(("Float8E4M3B11FNUZx16")) +float8_e4m3b11fnuzx32 = func_gen(("Float8E4M3B11FNUZx32")) +float8_e4m3b11fnuzx64 = func_gen(("Float8E4M3B11FNUZx64")) + float8_e4m3fn = func_gen(("Float8E4M3FN")) +float8_e4m3fnx2 = func_gen(("Float8E4M3FNx2")) float8_e4m3fnx4 = func_gen(("Float8E4M3FNx4")) float8_e4m3fnx8 = 
func_gen(("Float8E4M3FNx8")) float8_e4m3fnx16 = func_gen(("Float8E4M3FNx16")) float8_e4m3fnx32 = func_gen(("Float8E4M3FNx32")) float8_e4m3fnx64 = func_gen(("Float8E4M3FNx64")) +float8_e4m3fnuz = func_gen(("Float8E4M3FNUZ")) +float8_e4m3fnuzx2 = func_gen(("Float8E4M3FNUZx2")) +float8_e4m3fnuzx4 = func_gen(("Float8E4M3FNUZx4")) +float8_e4m3fnuzx8 = func_gen(("Float8E4M3FNUZx8")) +float8_e4m3fnuzx16 = func_gen(("Float8E4M3FNUZx16")) +float8_e4m3fnuzx32 = func_gen(("Float8E4M3FNUZx32")) +float8_e4m3fnuzx64 = func_gen(("Float8E4M3FNUZx64")) + float8_e5m2 = func_gen(("Float8E5M2")) +float8_e5m2x2 = func_gen(("Float8E5M2x2")) float8_e5m2x4 = func_gen(("Float8E5M2x4")) float8_e5m2x8 = func_gen(("Float8E5M2x8")) float8_e5m2x16 = func_gen(("Float8E5M2x16")) float8_e5m2x32 = func_gen(("Float8E5M2x32")) float8_e5m2x64 = func_gen(("Float8E5M2x64")) +float8_e5m2fnuz = func_gen(("Float8E5M2FNUZ")) +float8_e5m2fnuzx2 = func_gen(("Float8E5M2FNUZx2")) +float8_e5m2fnuzx4 = func_gen(("Float8E5M2FNUZx4")) +float8_e5m2fnuzx8 = func_gen(("Float8E5M2FNUZx8")) +float8_e5m2fnuzx16 = func_gen(("Float8E5M2FNUZx16")) +float8_e5m2fnuzx32 = func_gen(("Float8E5M2FNUZx32")) +float8_e5m2fnuzx64 = func_gen(("Float8E5M2FNUZx64")) + +float8_e8m0fnu = func_gen(("Float8E8M0FNU")) +float8_e8m0fnux2 = func_gen(("Float8E8M0FNUx2")) +float8_e8m0fnux4 = func_gen(("Float8E8M0FNUx4")) +float8_e8m0fnux8 = func_gen(("Float8E8M0FNUx8")) +float8_e8m0fnux16 = func_gen(("Float8E8M0FNUx16")) +float8_e8m0fnux32 = func_gen(("Float8E8M0FNUx32")) +float8_e8m0fnux64 = func_gen(("Float8E8M0FNUx64")) + +# Float6 variants +float6_e2m3fn = func_gen(("Float6E2M3FN")) +float6_e2m3fnx2 = func_gen(("Float6E2M3FNx2")) +float6_e2m3fnx4 = func_gen(("Float6E2M3FNx4")) +float6_e2m3fnx8 = func_gen(("Float6E2M3FNx8")) +float6_e2m3fnx16 = func_gen(("Float6E2M3FNx16")) +float6_e2m3fnx32 = func_gen(("Float6E2M3FNx32")) +float6_e2m3fnx64 = func_gen(("Float6E2M3FNx64")) + +float6_e3m2fn = func_gen(("Float6E3M2FN")) +float6_e3m2fnx2 = func_gen(("Float6E3M2FNx2")) +float6_e3m2fnx4 = func_gen(("Float6E3M2FNx4")) +float6_e3m2fnx8 = func_gen(("Float6E3M2FNx8")) +float6_e3m2fnx16 = func_gen(("Float6E3M2FNx16")) +float6_e3m2fnx32 = func_gen(("Float6E3M2FNx32")) +float6_e3m2fnx64 = func_gen(("Float6E3M2FNx64")) + +# Float4 variants float4_e2m1fn = func_gen(("Float4E2M1FN")) float4_e2m1fnx2 = func_gen(("Float4E2M1FNx2")) float4_e2m1fnx4 = func_gen(("Float4E2M1FNx4")) @@ -1961,7 +2035,31 @@ def wrapped(*args, **kwargs): # pylint: enable=invalid-name -__all__ = [ +bases = [ + "float8_e3m4", + "float8_e4m3", + "float8_e4m3b11fnuz", + "float8_e4m3fn", + "float8_e4m3fnuz", + "float8_e5m2", + "float8_e5m2fnuz", + "float8_e8m0fnu", + "float6_e2m3fn", + "float6_e3m2fn", + "float4_e2m1fn", + "float16", + "float32", + "float64", +] +lanes = [1, 2, 4, 8, 16, 32, 64] + +float_types = [] +for base in bases: + for lane in lanes: + suffix = f"x{lane}" if lane != 1 else "" + float_types.append(f"{base}{suffix}") + +__all__ = float_types + [ "int8", "int16", "int32", @@ -2010,43 +2108,6 @@ def wrapped(*args, **kwargs): "uint16x64", "uint32x64", "uint64x64", - "float8_e4m3fn", - "float8_e5m2", - "float4_e2m1fn", - "float16", - "float32", - "float64", - "float4_e2m1fnx2", - "float8_e4m3fnx4", - "float8_e5m2x4", - "float4_e2m1fnx4", - "float16x4", - "float32x4", - "float64x4", - "float8_e4m3fnx8", - "float8_e5m2x8", - "float4_e2m1fnx8", - "float16x8", - "float32x8", - "float64x8", - "float8_e4m3fnx16", - "float8_e5m2x16", - "float4_e2m1fnx16", - "float16x16", - "float32x16", - "float64x16", - 
"float8_e4m3fnx32", - "float8_e5m2x32", - "float4_e2m1fnx32", - "float16x32", - "float32x32", - "float64x32", - "float8_e4m3fnx64", - "float8_e5m2x64", - "float4_e2m1fnx64", - "float16x64", - "float32x64", - "float64x64", "bfloat16", "buffer", "buffer_decl", diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 387572f6427b..fcfd8deeb11f 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -73,8 +73,8 @@ TVM_REGISTER_NODE_TYPE(IntImmNode); FloatImm::FloatImm(DataType dtype, double value, Span span) { ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar."; - ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float4() || - dtype.code() >= DataType::kCustomBegin) + ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || + dtype.is_float4() || dtype.code() >= DataType::kCustomBegin) << "ValueError: FloatImm supports only float, but " << dtype << " was supplied."; // check range for float32 and float16 since they have specified range. @@ -94,18 +94,69 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) { << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; ICHECK_LE(value, support::kMaxBFloat16) << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; - } else if (dtype.is_float8()) { - double bound = - (dtype.code() == DataType::kFloat8_e4m3fn) ? support::kMaxE4M3FN : support::kMaxE5M2; - ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << " exceeds minimum of " + } else if (dtype.is_float8_e3m4() || dtype.is_float8_e4m3() || dtype.is_float8_e4m3b11fnuz() || + dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() || dtype.is_float8_e5m2() || + dtype.is_float8_e5m2fnuz() || dtype.is_float8_e8m0fnu()) { + double bound = 0.0; + bool nonneg = false; + + switch (dtype.code()) { + case DataType::TypeCode::kFloat8_e3m4: + bound = support::kMaxE3M4; + break; + case DataType::TypeCode::kFloat8_e4m3: + bound = support::kMaxE4M3; + break; + case DataType::TypeCode::kFloat8_e4m3b11fnuz: + bound = support::kMaxE4M3B11FNUZ; + nonneg = true; + break; + case DataType::TypeCode::kFloat8_e4m3fn: + bound = support::kMaxE4M3FN; + break; + case DataType::TypeCode::kFloat8_e4m3fnuz: + bound = support::kMaxE4M3FNUZ; + nonneg = true; + break; + case DataType::TypeCode::kFloat8_e5m2: + bound = support::kMaxE5M2; + break; + case DataType::TypeCode::kFloat8_e5m2fnuz: + bound = support::kMaxE5M2FNUZ; + nonneg = true; + break; + case DataType::TypeCode::kFloat8_e8m0fnu: + bound = support::kMaxE8M0FNU; + nonneg = true; + break; + default: + LOG(FATAL) << "Unhandled float8 type: " << dtype; + } + + if (nonneg) { + ICHECK_GE(value, 0) << "ValueError: Literal value " << value << " below zero for unsigned " + << dtype; + } else { + ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << " below minimum of " + << dtype; + } + ICHECK_LE(value, bound) << "ValueError: Literal value " << value << " exceeds maximum of " + << dtype; + + } else if (dtype.is_float6_e2m3fn() || dtype.is_float6_e3m2fn()) { + double bound = (dtype.code() == DataType::TypeCode::kFloat6_e2m3fn) ? 
support::kMaxE2M3FN + : support::kMaxE3M2FN; + ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << " below minimum of " << dtype; - ICHECK_LE(value, bound) << "ValueError: Literal vaule " << value << " exceeds maximum of " + ICHECK_LE(value, bound) << "ValueError: Literal value " << value << " exceeds maximum of " + << dtype; + + } else if (dtype.is_float4_e2m1fn()) { + double bound = support::kMaxE2M1FN; + ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << " below minimum of " + << dtype; + ICHECK_LE(value, bound) << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; - } else if (dtype.is_float4()) { - ICHECK_GE(value, -support::kMaxE2M1FN) - << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; - ICHECK_LE(value, support::kMaxE2M1FN) - << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; } } ObjectPtr node = make_object(); diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index 3e3145c32f5c..32155408fea4 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -111,8 +111,7 @@ size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional mem_scope) { for (int i = 0; i < arr.ndim; ++i) { size *= static_cast(arr.shape[i]); } - size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8; - return size; + return ffi::GetDataSize(size, arr.dtype); } LOG(FATAL) << "Device does not support physical mem computation with " << "specified memory scope: " << mem_scope.value(); diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 2bf56e876164..f03a83a929ec 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -46,6 +46,10 @@ inline void VerifyDataType(DLDataType dtype) { return; else if (dtype.bits == 4 && dtype.code == kDLInt) return; + else if (dtype.bits == 6 && dtype.code == DataType::kFloat6_e2m3fn) + return; + else if (dtype.bits == 6 && dtype.code == DataType::kFloat6_e3m2fn) + return; else if (dtype.bits == 4 && dtype.code == DataType::kFloat4_e2m1fn) return; else diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index da772f608579..2d61ca3e75f5 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -759,12 +759,42 @@ TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.BFloat16").set_body_typed(BFloat16); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); + +// Float8 variants +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E3M4").set_body_typed(Float8E3M4); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E3M4", Float8E3M4); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3").set_body_typed(Float8E4M3); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3", Float8E4M3); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3B11FNUZ") + .set_body_typed(Float8E4M3B11FNUZ); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); + 
+TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FNUZ").set_body_typed(Float8E4M3FNUZ); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2FNUZ").set_body_typed(Float8E5M2FNUZ); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E8M0FNU").set_body_typed(Float8E8M0FNU); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU); + +// Float6 variants +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float6E2M3FN").set_body_typed(Float6E2M3FN); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float6E3M2FN").set_body_typed(Float6E3M2FN); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN); + +// Float4 variant TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1FN").set_body_typed(Float4E2M1FN); TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); diff --git a/src/support/scalars.h b/src/support/scalars.h index 6d2a6e868363..d9f2d7c54316 100644 --- a/src/support/scalars.h +++ b/src/support/scalars.h @@ -61,13 +61,37 @@ constexpr double kMaxFloat16 = 65504.0; // See https://en.wikipedia.org/wiki/Bfloat16_floating-point_format constexpr double kMaxBFloat16 = 3.895313892515354759047080037148786688e38; -// 2^8 * (1 + 6/8) +// 2^15 * (1 + 3/4) // See https://arxiv.org/pdf/2209.05433.pdf -constexpr double kMaxE4M3FN = 448; +constexpr double kMaxE5M2 = 57344; // 2^15 * (1 + 3/4) +constexpr double kMaxE5M2FNUZ = 57344; + +// 2^8 * (1 + 6/8) // See https://arxiv.org/pdf/2209.05433.pdf -constexpr double kMaxE5M2 = 57344; +constexpr double kMaxE4M3FN = 448; + +// 2^8 * (1 + 6/8) +constexpr double kMaxE4M3 = 448; + +// 2^8 * (1 + 6/8) +constexpr double kMaxE4M3FNUZ = 448; + +// 2^4 * 1.875 +constexpr double kMaxE4M3B11FNUZ = 30; + +// 2^4 * 1.9375 +constexpr double kMaxE3M4 = 31; + +// 2^(255 - 127) +constexpr double kMaxE8M0FNU = 3.4028236692093846e38; + +// 2^2 * (1 + 7/8) +constexpr double kMaxE2M3FN = 7.5; + +// 2^4 * (1 + 3/4) +constexpr double kMaxE3M2FN = 28.0; // 2^2 * (1 + 1/2) constexpr double kMaxE2M1FN = 6.0; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 634c9c2b57a5..e9bcfa97fd01 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -579,8 +579,15 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { default: LOG(FATAL) << "do not support " << dtype; } - } else if (dtype.code() == DataType::kFloat8_e4m3fn || dtype.code() == DataType::kFloat8_e5m2) { + } else if (dtype.code() == DataType::kFloat8_e3m4 || dtype.code() == DataType::kFloat8_e4m3 || + dtype.code() == DataType::kFloat8_e4m3b11fnuz || + dtype.code() == DataType::kFloat8_e4m3fn || + dtype.code() == DataType::kFloat8_e4m3fnuz || dtype.code() == DataType::kFloat8_e5m2 || + dtype.code() == DataType::kFloat8_e5m2fnuz || + dtype.code() == DataType::kFloat8_e8m0fnu) { etype = llvm::Type::getInt8Ty(*ctx); + } else if (dtype.code() == DataType::kFloat6_e2m3fn || dtype.code() == DataType::kFloat6_e3m2fn) { + etype = llvm::Type::getIntNTy(*ctx, 6); } else if (dtype.code() == DataType::kFloat4_e2m1fn) { etype = 
llvm::Type::getIntNTy(*ctx, 4); } diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index c3014b11a5be..a2f868debb47 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -60,10 +60,13 @@ std::string GetFP8Type(DataType type) { } stream << "__nv_fp8"; std::string suffix; - if (type.code() == DataType::kFloat8_e4m3fn) { + if (type.code() == DataType::kFloat8_e4m3fn || type.code() == DataType::kFloat8_e4m3fnuz || + type.code() == DataType::kFloat8_e4m3 || type.code() == DataType::kFloat8_e4m3b11fnuz) { suffix = "_e4m3"; - } else if (type.code() == DataType::kFloat8_e5m2) { + } else if (type.code() == DataType::kFloat8_e5m2 || type.code() == DataType::kFloat8_e5m2fnuz) { suffix = "_e5m2"; + } else if (type.code() == DataType::kFloat8_e8m0fnu) { + suffix = "_e8m0"; } else { LOG(FATAL) << "Unsupported FP8 type in CUDA codegen"; } @@ -71,6 +74,36 @@ std::string GetFP8Type(DataType type) { return stream.str(); } +std::string GetFP6Type(DataType type) { + std::stringstream stream; + int32_t lanes = type.lanes(); + std::string vec; + if (type.is_scalar()) { + vec = ""; + } else if (lanes == 2) { + vec = "x2"; + } else if (lanes == 4) { + vec = "x4"; + } else if (lanes == 8) { + vec = "x8"; + } else if (lanes == 16) { + vec = "x16"; + } else { + LOG(FATAL) << "Only support scalar and vector types of width (2, 4) for FP6"; + } + stream << "__nv_fp6"; + std::string suffix; + if (type.code() == DataType::kFloat6_e2m3fn) { + suffix = "_e2m3"; + } else if (type.code() == DataType::kFloat6_e3m2fn) { + suffix = "_e3m2"; + } else { + LOG(FATAL) << "Unsupported FP6 type in CUDA codegen"; + } + stream << vec << suffix; + return stream.str(); +} + std::string GetFP4Type(DataType type) { std::stringstream stream; int32_t lanes = type.lanes(); @@ -81,15 +114,19 @@ std::string GetFP4Type(DataType type) { vec = "x2"; } else if (lanes == 4) { vec = "x4"; + } else if (lanes == 8) { + vec = "x8"; + } else if (lanes == 16) { + vec = "x16"; } else { - LOG(FATAL) << "Only support scalar and vector types of width (2, 4) for FP8"; + LOG(FATAL) << "Only support scalar and vector types of width (2, 4) for FP4"; } stream << "__nv_fp4"; std::string suffix; if (type.code() == DataType::kFloat4_e2m1fn) { suffix = "_e2m1"; } else { - LOG(FATAL) << "Unsupported FP8 type in CUDA codegen"; + LOG(FATAL) << "Unsupported FP4 type in CUDA codegen"; } stream << vec << suffix; return stream.str(); @@ -187,11 +224,38 @@ std::string CodeGenCUDA::Finish() { decl_stream << "using fp8_e5x4_t = __nv_fp8x4_e5m2;\n"; decl_stream << "struct fp8_e5x8_t {\n fp8_e5_t data[8]; \n};\n"; decl_stream << "struct fp8_e5x16_t {\n fp8_e5_t data[16]; \n};\n"; + decl_stream << "using fp8_e8_t = __nv_fp8_e8m0;\n"; + decl_stream << "using fp8_e8x2_t = __nv_fp8x2_e8m0;\n"; + decl_stream << "using fp8_e8x4_t = __nv_fp8x4_e8m0;\n"; + decl_stream << "struct fp8_e8x8_t {\n fp8_e8_t data[8]; \n};\n"; + decl_stream << "struct fp8_e8x16_t {\n fp8_e8_t data[16]; \n};\n"; decl_stream << "#endif\n\n"; } + + if (enable_fp6_) { + decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)\n"; + decl_stream << "#include \n"; + decl_stream << "using fp6_e2_t = __nv_fp6_e2m3;\n"; + decl_stream << "using fp6_e2x2_t = __nv_fp6x2_e2m3;\n"; + decl_stream << "using fp6_e2x4_t = __nv_fp6x4_e2m3;\n"; + decl_stream << "struct fp6_e2x8_t {\n fp6_e2_t data[8]; \n};\n"; + decl_stream << "struct fp6_e2x16_t {\n fp6_e2_t data[16]; \n};\n"; + decl_stream << "using fp6_e3_t = __nv_fp6_e3m2;\n"; + decl_stream << 
"using fp6_e3x2_t = __nv_fp6x2_e3m2;\n"; + decl_stream << "using fp6_e3x4_t = __nv_fp6x4_e3m2;\n"; + decl_stream << "struct fp6_e3x8_t {\n fp6_e3_t data[8]; \n};\n"; + decl_stream << "struct fp6_e3x16_t {\n fp6_e3_t data[16]; \n};\n"; + decl_stream << "#endif\n\n"; + } + if (enable_fp4_) { decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n"; decl_stream << "#include \n"; + decl_stream << "using fp4_e2_t = __nv_fp4_e2m1;\n"; + decl_stream << "using fp4_e2x2_t = __nv_fp4x2_e2m1;\n"; + decl_stream << "using fp4_e2x4_t = __nv_fp4x4_e2m1;\n"; + decl_stream << "struct fp4_e2x8_t {\n fp4_e2_t data[8]; \n};\n"; + decl_stream << "struct fp4_e2x16_t {\n fp4_e2_t data[16]; \n};\n"; decl_stream << "#endif\n\n"; } declare_vector_type_extensions(decl_stream, enable_fp16_, enable_bf16_, enable_fp8_, enable_fp4_); @@ -349,6 +413,14 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "uint" << t.lanes() / 4; } return; + } else if (t.is_float6()) { + enable_fp6_ = true; + if (t.lanes() <= 4) { + os << GetFP6Type(t); + } else { + fail = true; + } + return; } else if (t.is_float4()) { enable_fp4_ = true; if (t.lanes() <= 4) { @@ -744,9 +816,20 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { // Emit simple C-style type conversion. if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); - if (target_ty.code() == DataType::kFloat8_e4m3fn || target_ty.code() == DataType::kFloat8_e5m2 || - target_ty.code() == DataType::kFloat4_e2m1fn || from_ty.code() == DataType::kFloat8_e4m3fn || - from_ty.code() == DataType::kFloat8_e5m2 || from_ty.code() == DataType::kFloat4_e2m1fn) { + if (target_ty.code() == DataType::kFloat8_e3m4 || target_ty.code() == DataType::kFloat8_e4m3 || + target_ty.code() == DataType::kFloat8_e4m3b11fnuz || + target_ty.code() == DataType::kFloat8_e4m3fn || + target_ty.code() == DataType::kFloat8_e4m3fnuz || + target_ty.code() == DataType::kFloat8_e5m2 || + target_ty.code() == DataType::kFloat8_e5m2fnuz || + target_ty.code() == DataType::kFloat8_e8m0fnu || + target_ty.code() == DataType::kFloat4_e2m1fn || + + from_ty.code() == DataType::kFloat8_e3m4 || from_ty.code() == DataType::kFloat8_e4m3 || + from_ty.code() == DataType::kFloat8_e4m3b11fnuz || + from_ty.code() == DataType::kFloat8_e4m3fn || from_ty.code() == DataType::kFloat8_e4m3fnuz || + from_ty.code() == DataType::kFloat8_e5m2 || from_ty.code() == DataType::kFloat8_e5m2fnuz || + from_ty.code() == DataType::kFloat8_e8m0fnu || from_ty.code() == DataType::kFloat4_e2m1fn) { std::ostringstream val; if (target_ty.code() == DataType::kBFloat && target_ty.lanes() == 2) { val << "cast_to_nv_bfloat162(" << PrintExpr(op->value) << ")"; diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index ed5709ac12be..bfd5794f81f0 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -42,8 +42,8 @@ class CodeGenCUDA final : public CodeGenC { void Init(bool output_ssa); std::string Finish(); bool need_include_path() { - return (enable_fp16_ || enable_bf16_ || enable_int8_ || enable_fp8_ || enable_fp4_ || - need_math_constants_h_ || need_mma_h_); + return (enable_fp16_ || enable_bf16_ || enable_int8_ || enable_fp8_ || enable_fp6_ || + enable_fp4_ || need_math_constants_h_ || need_mma_h_); } // override behavior void PrintFuncPrefix(std::ostream& os) final; @@ -96,6 +96,8 @@ class CodeGenCUDA final : public CodeGenC { bool enable_bf16_{false}; // whether enable fp8 bool enable_fp8_{false}; + // whether enable fp6 + bool 
enable_fp6_{false}; // whether enable fp4 bool enable_fp4_{false}; // whether enable int8 diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index 039d89b93feb..f8519bc6c10e 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -454,6 +454,26 @@ struct __align__(8) half4_bfloat164 { (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16)); return result; } + __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e8m0& fp8x4) { + __nv_fp8x2_e8m0 lo_part, hi_part; + lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF); + hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF); + TVec2 lo_half2 = static_cast(lo_part); + TVec2 hi_half2 = static_cast(hi_part); + x = reinterpret_cast(&lo_half2)[0]; + y = reinterpret_cast(&lo_half2)[1]; + z = reinterpret_cast(&hi_half2)[0]; + w = reinterpret_cast(&hi_half2)[1]; + } + __host__ __device__ explicit operator __nv_fp8x4_e8m0() const { + __nv_fp8x4_e8m0 result; + TVec2 lo_half2 = *reinterpret_cast(&x); + TVec2 hi_half2 = *reinterpret_cast(&z); + __nv_fp8x2_e8m0 lo_part(lo_half2), hi_part(hi_half2); + result.__x = + (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16)); + return result; + } )"; } if (enable_fp4) { @@ -519,6 +539,22 @@ __host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e4m3& fp8 nv_bfloat16 y = nv_bfloat16(elem1); return nv_bfloat162(x, y); } +__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e5m2& fp8x2) { + __nv_fp8_e5m2 elem0, elem1; + elem0.__x = static_cast<__nv_fp8_storage_t>(fp8x2.__x & 0xFF); + elem1.__x = static_cast<__nv_fp8_storage_t>((fp8x2.__x >> 8) & 0xFF); + nv_bfloat16 x = nv_bfloat16(elem0); + nv_bfloat16 y = nv_bfloat16(elem1); + return nv_bfloat162(x, y); +} +__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e8m0& fp8x2) { + __nv_fp8_e8m0 elem0, elem1; + elem0.__x = static_cast<__nv_fp8_storage_t>(fp8x2.__x & 0xFF); + elem1.__x = static_cast<__nv_fp8_storage_t>((fp8x2.__x >> 8) & 0xFF); + nv_bfloat16 x = nv_bfloat16(elem0); + nv_bfloat16 y = nv_bfloat16(elem1); + return nv_bfloat162(x, y); +} )"; } } @@ -544,6 +580,16 @@ __device__ __nv_fp8x4_e4m3 make___nv_fp8x4_e4m3(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24); return result; } +__device__ __nv_fp8x2_e8m0 make___nv_fp8x2_e8m0(__nv_fp8_e8m0 x, __nv_fp8_e8m0 y) { + __nv_fp8x2_e8m0 result; + result.__x = (x.__x) | (y.__x << 8); + return result; +} +__device__ __nv_fp8x4_e8m0 make___nv_fp8x4_e8m0(__nv_fp8_e8m0 a, __nv_fp8_e8m0 b, __nv_fp8_e8m0 c, __nv_fp8_e8m0 d) { + __nv_fp8x4_e8m0 result; + result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24); + return result; +} )"; } if (enable_fp4) { diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 341a96cae697..41e090c58542 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -201,6 +201,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } else if (ltype.is_float8() && !rtype.is_float8()) { // Cast int->float8 for rhs when lhs is a float8 rhs = cast(ltype, rhs); + } else if (!ltype.is_float6() && rtype.is_float6()) { + // Cast int->float6 for lhs when rhs is a float6 + lhs = cast(rtype, lhs); + } else if (ltype.is_float6() && !rtype.is_float6()) { + // Cast int->float6 for rhs when lhs is a float6 + rhs = cast(ltype, rhs); } else if (!ltype.is_float4() && 
rtype.is_float4()) { // Cast int->float4 for lhs when rhs is a float4 lhs = cast(rtype, lhs); @@ -275,8 +281,25 @@ PrimExpr max_value(const DataType& dtype, Span span) { // according to https://arxiv.org/pdf/2209.05433.pdf if (dtype.code() == DataType::TypeCode::kFloat8_e5m2) { return FloatImm(dtype, 57344.0, span); + } else if (dtype.code() == DataType::TypeCode::kFloat8_e5m2fnuz) { + return FloatImm(dtype, 57344.0, span); } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3fn) { return FloatImm(dtype, 448.0, span); + } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3fnuz || + dtype.code() == DataType::TypeCode::kFloat8_e4m3) { + return FloatImm(dtype, 448.0, span); + } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3b11fnuz) { + return FloatImm(dtype, 30.0, span); + } else if (dtype.code() == DataType::TypeCode::kFloat8_e3m4) { + return FloatImm(dtype, 31.0, span); + } else if (dtype.code() == DataType::TypeCode::kFloat8_e8m0fnu) { + return FloatImm(dtype, 3.4028236692093846e+38, span); + } + } else if (dtype.is_float6()) { + if (dtype.code() == DataType::TypeCode::kFloat6_e2m3fn) { + return FloatImm(dtype, 7.5, span); + } else if (dtype.code() == DataType::TypeCode::kFloat6_e3m2fn) { + return FloatImm(dtype, 28.0, span); } } else if (dtype.is_float4()) { return FloatImm(dtype, 6.0, span); @@ -318,8 +341,26 @@ PrimExpr min_value(const DataType& dtype, Span span) { // according to https://arxiv.org/pdf/2209.05433.pdf if (dtype.code() == DataType::TypeCode::kFloat8_e5m2) { return FloatImm(dtype, -57344.0, span); + } else if (dtype.code() == DataType::TypeCode::kFloat8_e5m2fnuz) { + return FloatImm(dtype, 0.0, span); } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3fn) { return FloatImm(dtype, -448.0, span); + } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3fnuz) { + return FloatImm(dtype, 0.0, span); + } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3) { + return FloatImm(dtype, -448.0, span); + } else if (dtype.code() == DataType::TypeCode::kFloat8_e4m3b11fnuz) { + return FloatImm(dtype, 0.0, span); + } else if (dtype.code() == DataType::TypeCode::kFloat8_e3m4) { + return FloatImm(dtype, -31.0, span); + } else if (dtype.code() == DataType::TypeCode::kFloat8_e8m0fnu) { + return FloatImm(dtype, 0.0, span); + } + } else if (dtype.is_float6()) { + if (dtype.code() == DataType::TypeCode::kFloat6_e2m3fn) { + return FloatImm(dtype, -7.5, span); + } else if (dtype.code() == DataType::TypeCode::kFloat6_e3m2fn) { + return FloatImm(dtype, -28.0, span); } } else if (dtype.is_float4()) { return FloatImm(dtype, -6.0, span); diff --git a/src/tir/transforms/dtype_conversion.cc b/src/tir/transforms/dtype_conversion.cc index dfb0a5a63114..85341981d3c0 100644 --- a/src/tir/transforms/dtype_conversion.cc +++ b/src/tir/transforms/dtype_conversion.cc @@ -39,7 +39,8 @@ PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode ro CHECK_EQ(src_dtype.lanes(), tgt_dtype.lanes()) << "The lanes for data type for source value must matches the target datatype."; auto is_floating_point = [](DataType dtype) { - return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16() || dtype.is_float4(); + return dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || + dtype.is_float4(); }; // Both source dtype and target dtype should be floating point. 
CHECK(is_floating_point(src_dtype) && is_floating_point(tgt_dtype)); diff --git a/src/tir/transforms/dtype_conversion.h b/src/tir/transforms/dtype_conversion.h index a0ed6b5f6d86..bd6f34b3ac17 100644 --- a/src/tir/transforms/dtype_conversion.h +++ b/src/tir/transforms/dtype_conversion.h @@ -99,7 +99,8 @@ class FloatConfig { * \return The FloatConfig class containing internal floating point representation. */ static FloatConfig FromDataType(DataType dtype) { - CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float4()) + CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || + dtype.is_float4()) << "FloatConfig is only applicable to floating point data types, got " << dtype << " instead."; if (dtype.is_float()) { @@ -121,12 +122,43 @@ class FloatConfig { // NVIDIA/Arm/Intel's FP8 formats for Deep Learning // Reference: https://arxiv.org/abs/2209.05433 switch (dtype.code()) { + case DataType::kFloat8_e3m4: + // E3M4 format, not consistent with IEEE-754 + return FloatConfig(3, 4, 3, InftyStyle::kNone, NaNStyle::kAllOnes); + case DataType::kFloat8_e4m3: + // E4M3 format, not consistent with IEEE-754 + return FloatConfig(4, 3, 7, InftyStyle::kNone, NaNStyle::kAllOnes); + case DataType::kFloat8_e4m3b11fnuz: + // E4M3 variant with b11 encoding, not consistent with IEEE-754 + return FloatConfig(4, 3, 7, InftyStyle::kNone, NaNStyle::kAllOnes); case DataType::kFloat8_e4m3fn: // E4M3 format, not consistent with IEEE-754 return FloatConfig(4, 3, 7, InftyStyle::kNone, NaNStyle::kAllOnes); - default: - // E5M2 format, consistent with IEEE-754 + case DataType::kFloat8_e4m3fnuz: + // UE4M3 format, not consistent with IEEE-754 + return FloatConfig(4, 3, 7, InftyStyle::kNone, NaNStyle::kAllOnes); + case DataType::kFloat8_e5m2: + // UE5M2 format, consistent with IEEE-754 return FloatConfig(5, 2, 15, InftyStyle::kIEEE, NaNStyle::kIEEE); + case DataType::kFloat8_e5m2fnuz: + // UE5M2 format, not consistent with IEEE-754 + return FloatConfig(5, 2, 15, InftyStyle::kNone, NaNStyle::kAllOnes); + case DataType::kFloat8_e8m0fnu: + // UE8M0 format, not consistent with IEEE-754 + return FloatConfig(8, 0, 127, InftyStyle::kNone, NaNStyle::kAllOnes); + default: + LOG(FATAL) << "Unknown float8 variant: " << dtype; + } + } else if (dtype.is_float6()) { // float6 + switch (dtype.code()) { + case DataType::kFloat6_e2m3fn: + // E2M3 format, not consistent with IEEE-754 + return FloatConfig(2, 3, 1, InftyStyle::kNone, NaNStyle::kNone); + case DataType::kFloat6_e3m2fn: + // E3M2 format, not consistent with IEEE-754 + return FloatConfig(3, 2, 3, InftyStyle::kNone, NaNStyle::kNone); + default: + LOG(FATAL) << "Unknown float6 variant: " << dtype; } } else { // float4 diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py b/tests/python/codegen/test_target_codegen_cuda_fp4.py index 70e20b32d88b..364f9461c2f9 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp4.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py @@ -18,6 +18,7 @@ from itertools import product import numpy as np +import pytest import tvm import tvm.testing @@ -28,14 +29,11 @@ except ImportError: ml_dtypes = None -native_dtype, promoted_dtype = tvm.testing.parameters( - ("float4_e2m1fnx2", "float32x2"), - ("float4_e2m1fnx2", "float16x2"), -) - +@pytest.mark.parametrize("promoted_dtype", ["float32x2", "float16x2"]) @tvm.testing.requires_cuda_compute_version(10) -def test_e2m1_vector_conversions(native_dtype, promoted_dtype): +def test_e2m1_vector_conversions(promoted_dtype): 
+ native_dtype = "float4_e2m1fnx2" vector_length = 64 @T.prim_func @@ -45,7 +43,6 @@ def add( C: T.Buffer((vector_length,), native_dtype), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): for i in range(vector_length): with T.block("C"): v_i = T.axis.spatial(vector_length, i) @@ -92,125 +89,6 @@ def add( ) -@tvm.testing.requires_cuda_compute_version(10) -def test_e2m1_schedule_vectorize(): - native_dtype = "float4_e2m1fn" - n = 128 - - dev = tvm.device("cuda", 0) - target = tvm.target.Target.from_device(dev) - for promoted_dtype, vector_length in product( - ["float16", "bfloat16", "float32"], - [1, 2, 4], - ): - - @T.prim_func - def add( - A: T.Buffer((n,), native_dtype), - B: T.Buffer((n,), native_dtype), - C: T.Buffer((n,), native_dtype), - ): - T.func_attr({"tir.noalias": True}) - for i in range(n): - with T.block("C"): - v_i = T.axis.spatial(n, i) - T.reads(A[v_i], B[v_i]) - T.writes(C[v_i]) - C[v_i] = T.Cast( - native_dtype, - T.Cast(promoted_dtype, A[v_i]) + T.Cast(promoted_dtype, B[v_i]), - ) - - sch = tvm.tir.Schedule(add) - block = sch.get_block("C") - b = sch.get_loops(block) - bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length]) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - sch.vectorize(vec) - - fadd = tvm.compile(sch.mod, target=target) - - numpytype = "float4_e2m1fn" - promoted_base_dtype = promoted_dtype - - a_np = np.random.uniform(low=-6, high=6, size=(n,)).astype(numpytype) - a = tvm.nd.empty(shape=(n,), dtype=native_dtype, device=dev) - a.copyfrom(a_np) - b_np = np.random.uniform(low=-6, high=6, size=(n,)).astype(numpytype) - b = tvm.nd.empty(shape=(n,), dtype=native_dtype, device=dev) - b.copyfrom(b_np) - c = tvm.nd.empty(shape=(n,), dtype=native_dtype, device=dev) - fadd(a, b, c) - - if promoted_base_dtype != "bfloat16": - tvm.testing.assert_allclose( - c.numpy().astype(promoted_base_dtype), (a_np + b_np).astype(promoted_base_dtype) - ) - else: - # assert_allclose with bfloat16 throws an error here. - # Thus we convert bfloat16 to float32 for comparison. - tvm.testing.assert_allclose( - c.numpy().astype(promoted_base_dtype).astype("float32"), - (a_np + b_np).astype(promoted_base_dtype).astype("float32"), - ) - - -@tvm.testing.requires_cuda_compute_version(10) -def test_e2m1_reinterpret(): - n = 128 - - dev = tvm.device("cuda", 0) - target = tvm.target.Target.from_device(dev) - - def get_reinterpret_mod(src_dtype, dst_dtype, vector_length): - @T.prim_func - def reinterpret( - A: T.Buffer((n,), src_dtype), - B: T.Buffer((n,), dst_dtype), - ): - T.func_attr({"tir.noalias": True}) - for i in range(n): - with T.block("C"): - v_i = T.axis.spatial(n, i) - T.reads(A[v_i]) - T.writes(B[v_i]) - B[v_i] = T.reinterpret(dst_dtype, A[v_i]) - - sch = tvm.tir.Schedule(reinterpret) - block = sch.get_block("C") - b = sch.get_loops(block) - bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length]) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - sch.vectorize(vec) - return sch.mod - - # Part 1. reinterpret float4_e2m1fn to uint8 - for vector_length in [1, 2, 4]: - mod = get_reinterpret_mod("float4_e2m1fn", "uint8", vector_length) - f = tvm.compile(mod, target=target) - a_np = np.random.uniform(low=-6, high=6, size=(n,)).astype("float4_e2m1fn") - a = tvm.nd.empty(shape=(n,), dtype="float4_e2m1fn", device=dev) - a.copyfrom(a_np) - b = tvm.nd.empty(shape=(n,), dtype="uint8", device=dev) - f(a, b) - tvm.testing.assert_allclose(b.numpy(), a_np.view("uint8")) - - # Part 2. 
reinterpret uint8 to float4_e2m1fn - for vector_length in [1, 2, 4]: - mod = get_reinterpret_mod("uint8", "float4_e2m1fn", vector_length) - f = tvm.compile(mod, target=target) - a_np = np.random.uniform(low=-6, high=6, size=(n,)).astype("uint8") - a = tvm.nd.empty(shape=(n,), dtype="uint8", device=dev) - a.copyfrom(a_np) - b = tvm.nd.empty(shape=(n,), dtype="float4_e2m1fn", device=dev) - f(a, b) - tvm.testing.assert_allclose( - b.numpy().astype("float32"), a_np.view("float4_e2m1fn").astype("float32") - ) - - @tvm.testing.requires_cuda_compute_version(10) def test_e2m1_dequantize(): n = 128 diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index da6fb918444d..139347a5b9d9 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -36,9 +36,18 @@ ml_dtypes = None -@tvm.testing.requires_cuda_compute_version(8, 9) -def test_e4m3_conversions(): - dtype = "float8_e4m3fn" +@pytest.mark.parametrize( + "input", + [ + ("float8_e4m3fn", "__nv_fp8_e4m3"), + ("float8_e4m3fnuz", "__nv_fp8_e4m3"), + ("float8_e5m2", "__nv_fp8_e5m2"), + ("float8_e5m2fnuz", "__nv_fp8_e5m2"), + ], +) +@tvm.testing.requires_cuda_compute_version(10) +def test_fp8_conversions(input): + dtype, nv_dtype = input @T.prim_func def add( @@ -47,7 +56,6 @@ def add( C: T.Buffer((64,), dtype), ): T.func_attr({"tir.noalias": True}) - # with T.block("root"): for i in range(64): with T.block("C"): v_i = T.axis.spatial(64, i) @@ -66,14 +74,13 @@ def add( fadd = tvm.tir.build(sch.mod, target=target) cuda_src = fadd.imported_modules[0].get_source() - assert "__nv_fp8_e4m3" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in generated CUDA" + assert nv_dtype in cuda_src, f"{nv_dtype} datatype not found in generated CUDA" dev = tvm.device(target, 0) - numpytype = "float8_e4m3fn" - a = tvm.nd.array(np.random.uniform(low=0, high=5, size=64).astype(numpytype), dev) - b = tvm.nd.array(np.random.uniform(low=0, high=5, size=64).astype(numpytype), dev) - c = tvm.nd.array(np.zeros(64, dtype=numpytype), dev) + a = tvm.nd.array(np.random.uniform(low=0, high=5, size=64).astype(dtype), dev) + b = tvm.nd.array(np.random.uniform(low=0, high=5, size=64).astype(dtype), dev) + c = tvm.nd.array(np.zeros(64, dtype=dtype), dev) fadd(a, b, c) tvm.testing.assert_allclose( @@ -81,11 +88,15 @@ def add( ) -@tvm.testing.requires_cuda_compute_version(8, 9) -def test_e4m3_packing(): +@pytest.mark.parametrize( + "dtype", + ["float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz", "float8_e8m0fnu"], +) +@tvm.testing.requires_cuda_compute_version(10) +def test_fp8_packing(dtype): length = 64 vector_length = 4 - native_dtype, packed_dtype = ("float8_e4m3fnx4", "uint32") + native_dtype, packed_dtype = (f"{dtype}x{vector_length}", "uint32") @T.prim_func def add( @@ -124,9 +135,8 @@ def add( f = tvm.compile(sch.mod, target=target) dev = tvm.device(target, 0) - numpytype = "float8_e4m3fn" np_shape = (length, vector_length) - a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype) + a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(dtype) a = tvm.nd.empty(shape=(length,), dtype=native_dtype, device=dev) r = tvm.nd.empty(shape=(length,), dtype=packed_dtype, device=dev) b = tvm.nd.empty(shape=(length,), dtype=native_dtype, device=dev) @@ -135,19 +145,25 @@ def add( tvm.testing.assert_allclose(a.numpy().astype("float16"), b.numpy().astype("float16")) -native_dtype, promoted_dtype = 
tvm.testing.parameters( - ("float8_e4m3fn", "float32"), - ("float8_e4m3fn", "float16"), - ("float8_e4m3fnx2", "float32x2"), - ("float8_e4m3fnx2", "float16x2"), - ("float8_e4m3fnx4", "float32x4"), +native_dtype, promoted_dtype, numpytype = tvm.testing.parameters( + ("float8_e4m3fn", "float32", "float8_e4m3fn"), + ("float8_e4m3fn", "float16", "float8_e4m3fn"), + ("float8_e4m3fnx2", "float32x2", "float8_e4m3fn"), + ("float8_e4m3fnx2", "float16x2", "float8_e4m3fn"), + ("float8_e4m3fnx4", "float32x4", "float8_e4m3fn"), # Supported via half4 vector type extension in codegen - ("float8_e4m3fnx4", "float16x4"), + ("float8_e4m3fnx4", "float16x4", "float8_e4m3fn"), + ("float8_e5m2", "float32", "float8_e5m2"), + ("float8_e5m2", "float16", "float8_e5m2"), + ("float8_e5m2x2", "float32x2", "float8_e5m2"), + ("float8_e5m2x2", "float16x2", "float8_e5m2"), + ("float8_e5m2x4", "float32x4", "float8_e5m2"), + ("float8_e5m2x4", "float16x4", "float8_e5m2"), ) -@tvm.testing.requires_cuda_compute_version(8, 9) -def test_e4m3_vector_conversions(native_dtype, promoted_dtype): +@tvm.testing.requires_cuda_compute_version(10) +def test_fp8_vector_conversions(native_dtype, promoted_dtype, numpytype): vector_length = 64 @T.prim_func @@ -179,7 +195,6 @@ def add( cuda_src = fadd.imported_modules[0].get_source() dev = tvm.device(target, 0) - numpytype = "float8_e4m3fn" if "x" in native_dtype: lanes = int(native_dtype.split("x")[-1]) else: @@ -801,8 +816,8 @@ def test_main(self, weight_shape, model_dtype, target_str, compiled_functions): tvm.testing.assert_allclose(weight_np, dequant_weight_np, atol=10, rtol=5e-2) -@tvm.testing.requires_cuda_compute_version(8, 9) -@pytest.mark.parametrize("dtype", ["float8_e5m2", "float8_e4m3fn"]) +@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.parametrize("dtype", ["float8_e5m2", "float8_e4m3fn", "float8_e8m0fnu"]) def test_const(dtype): @T.prim_func def func(A: T.Buffer((4,), dtype)) -> None: @@ -959,40 +974,42 @@ def _pipeline(mod: tvm.ir.IRModule) -> tvm.ir.IRModule: vm["main"](x, indptr, weight, scale) +@pytest.mark.parametrize("vec_length", [2, 4]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @tvm.testing.requires_cuda_compute_version(8, 9) -def test_fp8_fp16_bf16_vectorize_arith(): - for vec_length, dtype in product([2, 4], ["float16", "bfloat16"]): - - @T.prim_func - def func_vectorize( - A: T.Buffer((128,), "float8_e4m3fn"), - B: T.Buffer((128,), dtype), - C: T.Buffer((128,), dtype), - ) -> None: - for i in T.serial(128): - with T.block("compute"): - vi = T.axis.remap("S", [i]) - C[vi] = (A[vi].astype(dtype) * B[vi]) + T.bfloat16(3.0) - - sch = tir.Schedule(func_vectorize) - (l,) = sch.get_loops(sch.get_block("compute")) - lo, li = sch.split(l, [None, vec_length]) - sch.bind(lo, "threadIdx.x") - sch.vectorize(li) - - device = tvm.cuda() - target = tvm.target.Target.from_device(device) - f = tir.build(sch.mod, target=target) - - a_np = np.random.rand(128).astype("float8_e4m3fn") - b_np = np.random.rand(128).astype(dtype) - c_np = (a_np.astype(dtype) * b_np) + 3 - a_tvm = tvm.nd.array(a_np, device=device) - b_tvm = tvm.nd.array(b_np, device=device) - c_tvm = tvm.nd.empty((128,), dtype=dtype, device=device) - f(a_tvm, b_tvm, c_tvm) - c_tvm = c_tvm.numpy() - np.testing.assert_allclose(c_tvm, c_np, atol=1e-3, rtol=1e-3) +def test_fp8_fp16_bf16_vectorize_arith(vec_length, dtype): + @T.prim_func + def func_vectorize( + A: T.Buffer((128,), "float8_e4m3fn"), + B: T.Buffer((128,), dtype), + C: T.Buffer((128,), dtype), + ) -> None: + for i in T.serial(128): + with 
T.block("compute"): + vi = T.axis.remap("S", [i]) + C[vi] = (A[vi].astype(dtype) * B[vi]) + T.bfloat16(3.0) + + sch = tir.Schedule(func_vectorize) + (l,) = sch.get_loops(sch.get_block("compute")) + lo, li = sch.split(l, [None, vec_length]) + sch.bind(lo, "threadIdx.x") + sch.vectorize(li) + + device = tvm.cuda() + target = tvm.target.Target.from_device(device) + f = tir.build(sch.mod, target=target) + + a_np = np.random.rand(128).astype("float8_e4m3fn") + b_np = np.random.rand(128).astype(dtype) + c_np = (a_np.astype(dtype) * b_np) + 3 + a_tvm = tvm.nd.array(a_np, device=device) + b_tvm = tvm.nd.array(b_np, device=device) + c_tvm = tvm.nd.empty((128,), dtype=dtype, device=device) + f(a_tvm, b_tvm, c_tvm) + c_tvm = c_tvm.numpy() + np.testing.assert_allclose( + c_tvm.astype(np.float32), c_np.astype(np.float32), atol=5e-1, rtol=1e-2 + ) if __name__ == "__main__": diff --git a/tests/python/ffi/test_dtype.py b/tests/python/ffi/test_dtype.py index 2758edf9d6a3..332d0e1827d8 100644 --- a/tests/python/ffi/test_dtype.py +++ b/tests/python/ffi/test_dtype.py @@ -18,6 +18,8 @@ import pytest import pickle import numpy as np +import tvm +import tvm.testing from tvm import ffi as tvm_ffi @@ -31,7 +33,15 @@ def test_dtype(): @pytest.mark.parametrize( "dtype_str, expected_size", - [("float32", 4), ("float32x4", 16), ("float8_e5m2x4", 4), ("uint8", 1)], + [ + ("float32", 4), + ("float32x4", 16), + ("float8_e5m2x4", 4), + ("float6_e2m3fnx4", 3), + ("float4_e2m1fnx4", 2), + ("uint8", 1), + ("bool", 1), + ], ) def test_dtype_itemsize(dtype_str, expected_size): dtype = tvm_ffi.dtype(dtype_str) @@ -46,7 +56,15 @@ def test_dtype_itemmize_error(dtype_str): @pytest.mark.parametrize( "dtype_str", - ["float32", "float32x4", "float8_e5m2x4", "uint8"], + [ + "float32", + "float32x4", + "float8_e5m2x4", + "float6_e2m3fnx4", + "float4_e2m1fnx4", + "uint8", + "bool", + ], ) def test_dtype_pickle(dtype_str): dtype = tvm_ffi.dtype(dtype_str) @@ -56,9 +74,14 @@ def test_dtype_pickle(dtype_str): assert dtype_pickled.lanes == dtype.lanes -def test_dtype_with_lanes(): - dtype = tvm_ffi.dtype("float32") +@pytest.mark.parametrize("dtype_str", ["float32", "bool"]) +def test_dtype_with_lanes(dtype_str): + dtype = tvm_ffi.dtype(dtype_str) dtype_with_lanes = dtype.with_lanes(4) assert dtype_with_lanes.type_code == dtype.type_code assert dtype_with_lanes.bits == dtype.bits assert dtype_with_lanes.lanes == 4 + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/ir/test_datatype_nv_fp4.py b/tests/python/ir/test_datatype_nv_fp4.py new file mode 100644 index 000000000000..85047fc4a5fd --- /dev/null +++ b/tests/python/ir/test_datatype_nv_fp4.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import numpy as np + +import tvm +import tvm.testing +import tvm.tir as tir +from tvm import te +from tvm.script import tir as T + +try: + from ml_dtypes import float4_e2m1fn +except ImportError: + float4_e2m1fn = None + + +np_dtype, dtype_str = tvm.testing.parameters((float4_e2m1fn, "float4_e2m1fn")) + + +def test_create_nv_fp4_nd_array(np_dtype, dtype_str): + if np_dtype is None: + """Skip test if ml_dtypes is not installed""" + return + x = np.random.rand(128, 128).astype(np_dtype) + x_nd = tvm.nd.array(x) + assert x_nd.dtype == dtype_str + np.testing.assert_equal(x_nd.numpy(), x) + + +def test_nv_fp4_buffer(np_dtype, dtype_str): + m = te.size_var("m") + n = te.size_var("n") + A = tvm.tir.decl_buffer((m, n), dtype_str) + assert A.dtype == dtype_str + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/ir/test_datatype_nv_fp8.py b/tests/python/ir/test_datatype_nv_fp8.py index 72cdfb469d43..d27cc0314328 100644 --- a/tests/python/ir/test_datatype_nv_fp8.py +++ b/tests/python/ir/test_datatype_nv_fp8.py @@ -23,10 +23,19 @@ from tvm.script import tir as T try: - from ml_dtypes import float8_e4m3fn as float8_e4m3fn - from ml_dtypes import float8_e5m2 as float8_e5m2 + from ml_dtypes import ( + float8_e3m4, + float8_e4m3, + float8_e4m3b11fnuz, + float8_e4m3fn, + float8_e4m3fnuz, + float8_e5m2, + float8_e5m2fnuz, + float8_e8m0fnu, + ) except ImportError: - float8_e4m3fn, float8_e5m2 = None, None + float8_e3m4 = float8_e4m3 = float8_e4m3b11fnuz = float8_e4m3fn = None + float8_e4m3fnuz = float8_e5m2 = float8_e5m2fnuz = float8_e8m0fnu = None def fp8_unary(dtype: str): @@ -60,7 +69,14 @@ def func( np_dtype, dtype_str = tvm.testing.parameters( - (float8_e4m3fn, "float8_e4m3fn"), (float8_e5m2, "float8_e5m2") + (float8_e3m4, "float8_e3m4"), + (float8_e4m3, "float8_e4m3"), + (float8_e4m3b11fnuz, "float8_e4m3b11fnuz"), + (float8_e4m3fn, "float8_e4m3fn"), + (float8_e4m3fnuz, "float8_e4m3fnuz"), + (float8_e5m2, "float8_e5m2"), + (float8_e5m2fnuz, "float8_e5m2fnuz"), + (float8_e8m0fnu, "float8_e8m0fnu"), ) @@ -71,6 +87,7 @@ def test_create_nv_fp8_nd_array(np_dtype, dtype_str): x = np.random.rand(128, 128).astype(np_dtype) x_nd = tvm.nd.array(x) assert x_nd.dtype == dtype_str + np.testing.assert_equal(x_nd.numpy(), x) def test_fp8_unary_op(np_dtype, dtype_str): @@ -80,6 +97,9 @@ def test_fp8_unary_op(np_dtype, dtype_str): if np_dtype is None: """Skip test if ml_dtypes is not installed""" return + if dtype_str in ["float8_e8m0fnu", "float8_e4m3b11fnuz", "float8_e4m3fnuz", "float8_e5m2fnuz"]: + # float8_e8m0fnu does not support arithmetic operations, and unsigned arithmetic is not tested here + return f = tvm.compile(func, target="llvm") a = np.random.randn(128).astype(np_dtype) @@ -93,6 +113,9 @@ def test_fp8_unary_op(np_dtype, dtype_str): map(lambda _: tvm.nd.array(_), [a, b, a_add_b, a_sub_b, a_mul_b, a_fp32, a_roundtrip]) ) f(*args) + expected_a_fp32 = a.astype(np.float32) + expected_a_roundtrip = expected_a_fp32.astype(np_dtype) + np.testing.assert_equal(args[6].numpy(), expected_a_roundtrip) def test_nv_fp8_buffer(np_dtype, dtype_str): diff --git a/tests/python/ir/test_dtype.py b/tests/python/ir/test_dtype.py index cfa39fda0073..ddacd8583511 100644 --- a/tests/python/ir/test_dtype.py +++ b/tests/python/ir/test_dtype.py @@ -24,7 +24,14 @@ @pytest.mark.parametrize( "dtype_str, expected_size", - [("float32", 4), ("float32x4", 16), ("float8_e5m2x4", 4), ("uint8", 1)], + [ + ("float32", 4), + ("float32x4", 16), + ("float8_e5m2x4", 4), + ("float6_e2m3fnx4", 3), + ("float4_e2m1fnx4", 
2), + ("uint8", 1), + ], ) def test_dtype_itemsize(dtype_str, expected_size): dtype = DataType(dtype_str) diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 21d8fc942251..e03cb6c9d583 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -908,25 +908,31 @@ def func(): _assert_print(func, expected_output) -@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"]) -def test_float8(dtype): +CUSTOM_FLOAT_DTYPES = [ + # Float8 variants + "float8_e3m4", + "float8_e4m3", + "float8_e4m3b11fnuz", + "float8_e4m3fn", + "float8_e4m3fnuz", + "float8_e5m2", + "float8_e5m2fnuz", + "float8_e8m0fnu", + # Float6 variants + "float6_e2m3fn", + "float6_e3m2fn", + # Float4 variant + "float4_e2m1fn", +] + + +@pytest.mark.parametrize("dtype", CUSTOM_FLOAT_DTYPES) +def test_custom_float_types(dtype): from tvm.script import tir as T - def get_func(dtype): - if dtype == "float8_e4m3fn": - - @T.prim_func - def func(): - T.evaluate(T.float8_e4m3fn(0.0)) - - return func - elif dtype == "float8_e5m2": - - @T.prim_func - def func(): - T.evaluate(T.float8_e5m2(0.0)) - - return func + @T.prim_func() + def func(): + T.evaluate(getattr(T, dtype)(0.0)) expected_output = f""" # from tvm.script import tir as T @@ -934,8 +940,7 @@ def func(): @T.prim_func def func(): T.evaluate(T.{dtype}(0.0)) - """ - func = get_func(dtype) +""" _assert_print(func, expected_output) From 8d7c83cd624e813b2cdf913f1113920a2b99be61 Mon Sep 17 00:00:00 2001 From: Kathryn-cat Date: Mon, 2 Jun 2025 15:09:05 -0400 Subject: [PATCH 2/4] change NV naming; prevent CUDA codegen on unsupported dtypes Co-authored-by: DerrickYLJ --- include/tvm/runtime/data_type.h | 48 +++++++++---------- include/tvm/script/ir_builder/tir/ir.h | 27 +++++------ src/relax/op/tensor/qdq.cc | 7 ++- src/target/source/codegen_cuda.cc | 5 +- src/tir/transforms/vectorize_loop.cc | 4 +- .../codegen/test_target_codegen_cuda_fp8.py | 4 +- 6 files changed, 45 insertions(+), 50 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 9e9bcddc2957..6d8f9e09bff5 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -295,83 +295,83 @@ class DataType { */ static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); } /*! - * \brief Construct NV float8 e3m4 datatype. + * \brief Construct float8 e3m4 datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat8E3M4(int lanes = 1) { return DataType(kFloat8_e3m4, 8, lanes); } + static DataType Float8E3M4(int lanes = 1) { return DataType(kFloat8_e3m4, 8, lanes); } /*! - * \brief Construct NV float8 e4m3 datatype. + * \brief Construct float8 e4m3 datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3, 8, lanes); } + static DataType Float8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3, 8, lanes); } /*! - * \brief Construct NV float8 e4m3b11fnuz datatype. + * \brief Construct float8 e4m3b11fnuz datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat8E4M3B11FNUZ(int lanes = 1) { + static DataType Float8E4M3B11FNUZ(int lanes = 1) { return DataType(kFloat8_e4m3b11fnuz, 8, lanes); } /*! - * \brief Construct NV float8 e4m3fn datatype. + * \brief Construct float8 e4m3fn datatype. 
* \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat8E4M3FN(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); } + static DataType Float8E4M3FN(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); } /*! - * \brief Construct NV float8 e4m3fnuz datatype. + * \brief Construct float8 e4m3fnuz datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat8E4M3FNUZ(int lanes = 1) { return DataType(kFloat8_e4m3fnuz, 8, lanes); } + static DataType Float8E4M3FNUZ(int lanes = 1) { return DataType(kFloat8_e4m3fnuz, 8, lanes); } /*! - * \brief Construct NV float8 e5m2 datatype. + * \brief Construct float8 e5m2 datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); } + static DataType Float8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); } /*! - * \brief Construct NV float8 e5m2fnuz datatype. + * \brief Construct float8 e5m2fnuz datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat8E5M2FNUZ(int lanes = 1) { return DataType(kFloat8_e5m2fnuz, 8, lanes); } + static DataType Float8E5M2FNUZ(int lanes = 1) { return DataType(kFloat8_e5m2fnuz, 8, lanes); } /*! - * \brief Construct NV float8 e8m0fnu datatype. + * \brief Construct float8 e8m0fnu datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat8E8M0FNU(int lanes = 1) { return DataType(kFloat8_e8m0fnu, 8, lanes); } + static DataType Float8E8M0FNU(int lanes = 1) { return DataType(kFloat8_e8m0fnu, 8, lanes); } /*! - * \brief Construct NV float6 e2m3fn datatype. + * \brief Construct float6 e2m3fn datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat6E2M3FN(int lanes = 1) { return DataType(kFloat6_e2m3fn, 6, lanes); } + static DataType Float6E2M3FN(int lanes = 1) { return DataType(kFloat6_e2m3fn, 6, lanes); } /*! - * \brief Construct NV float6 e3m2fn datatype. + * \brief Construct float6 e3m2fn datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat6E3M2FN(int lanes = 1) { return DataType(kFloat6_e3m2fn, 6, lanes); } + static DataType Float6E3M2FN(int lanes = 1) { return DataType(kFloat6_e3m2fn, 6, lanes); } /*! - * \brief Construct NV float4 e2m1fn datatype. + * \brief Construct float4 e2m1fn datatype. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType NVFloat4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); } + static DataType Float4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); } /*! * \brief Construct a bool type. * \param lanes The number of lanes. 
@@ -418,8 +418,8 @@ inline int GetVectorBytes(DataType dtype) {
   int data_bits = dtype.bits() * dtype.lanes();
   // allow bool to exist
   if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
-      dtype == DataType::Int(1) || dtype == DataType::NVFloat4E2M1FN() ||
-      dtype == DataType::NVFloat6E2M3FN() || dtype == DataType::NVFloat6E3M2FN()) {
+      dtype == DataType::Int(1) || dtype == DataType::Float4E2M1FN() ||
+      dtype == DataType::Float6E2M3FN() || dtype == DataType::Float6E3M2FN()) {
     return 1;
   }
   ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes";
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index cfbaac7dc2a0..b36f5cd7384d 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -504,20 +504,19 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32));                \
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64));
 
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E3M4, DataType::NVFloat8E3M4);
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3, DataType::NVFloat8E4M3);
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3B11FNUZ,
-                                                   DataType::NVFloat8E4M3B11FNUZ);
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::NVFloat8E4M3FN);
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FNUZ, DataType::NVFloat8E4M3FNUZ);
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, DataType::NVFloat8E5M2);
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2FNUZ, DataType::NVFloat8E5M2FNUZ);
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E8M0FNU, DataType::NVFloat8E8M0FNU);
-
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E2M3FN, DataType::NVFloat6E2M3FN);
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E3M2FN, DataType::NVFloat6E3M2FN);
-
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::NVFloat4E2M1FN);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E3M4, DataType::Float8E3M4);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3, DataType::Float8E4M3);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3B11FNUZ, DataType::Float8E4M3B11FNUZ);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::Float8E4M3FN);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FNUZ, DataType::Float8E4M3FNUZ);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, DataType::Float8E5M2);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2FNUZ, DataType::Float8E5M2FNUZ);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E8M0FNU, DataType::Float8E8M0FNU);
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E2M3FN, DataType::Float6E2M3FN);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E3M2FN, DataType::Float6E3M2FN);
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::Float4E2M1FN);
 
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc
index 78ba6fec34ac..8cc338b62623 100644
--- a/src/relax/op/tensor/qdq.cc
+++ b/src/relax/op/tensor/qdq.cc
@@ -50,8 +50,7 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
   const auto* attrs = call->attrs.as<QuantizeAttrs>();
  if (attrs->out_dtype != DataType::Int(8) && attrs->out_dtype != DataType::UInt(8) &&
      attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16) &&
-      attrs->out_dtype != DataType::NVFloat8E4M3() &&
-      attrs->out_dtype != DataType::NVFloat8E5M2()) {
+      attrs->out_dtype != DataType::Float8E4M3FN() && attrs->out_dtype != DataType::Float8E5M2()) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "Unsupported output datatype attribute for operation: '"
                      << attrs->out_dtype);
@@ -145,8 +144,8 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
   // Check input datatype:
   if (input_sinfo->dtype != DataType::Int(8) && input_sinfo->dtype != DataType::UInt(8) &&
       input_sinfo->dtype != DataType::Int(16) && input_sinfo->dtype != DataType::UInt(16) &&
-      input_sinfo->dtype != DataType::Int(32) && input_sinfo->dtype != DataType::NVFloat8E4M3() &&
-      input_sinfo->dtype != DataType::NVFloat8E5M2() && input_sinfo->dtype != DataType::Float(16) &&
+      input_sinfo->dtype != DataType::Int(32) && input_sinfo->dtype != DataType::Float8E4M3FN() &&
+      input_sinfo->dtype != DataType::Float8E5M2() && input_sinfo->dtype != DataType::Float(16) &&
       input_sinfo->dtype != DataType::Float(32)) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "Unsupported input datatype for operation: " << attrs->out_dtype);
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index a2f868debb47..1557d78ad0c8 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -60,10 +60,9 @@ std::string GetFP8Type(DataType type) {
   }
   stream << "__nv_fp8";
   std::string suffix;
-  if (type.code() == DataType::kFloat8_e4m3fn || type.code() == DataType::kFloat8_e4m3fnuz ||
-      type.code() == DataType::kFloat8_e4m3 || type.code() == DataType::kFloat8_e4m3b11fnuz) {
+  if (type.code() == DataType::kFloat8_e4m3fn) {
     suffix = "_e4m3";
-  } else if (type.code() == DataType::kFloat8_e5m2 || type.code() == DataType::kFloat8_e5m2fnuz) {
+  } else if (type.code() == DataType::kFloat8_e5m2) {
     suffix = "_e5m2";
   } else if (type.code() == DataType::kFloat8_e8m0fnu) {
     suffix = "_e8m0";
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index 16aae03932cf..7ae226c100c0 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -503,8 +503,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
       return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {value});
     } else {
-      int new_lanes = (op->dtype != DataType::NVFloat4E2M1FN() &&
-                       op->args[0].dtype() != DataType::NVFloat4E2M1FN())
+      int new_lanes = (op->dtype != DataType::Float4E2M1FN() &&
+                       op->args[0].dtype() != DataType::Float4E2M1FN())
                           ? 
(value.dtype().bits() * value.dtype().lanes()) / op->dtype.bits() : value.dtype().lanes(); return Call(op->dtype.with_lanes(new_lanes), op->op, {value}); diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index 139347a5b9d9..aa9080a48882 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -40,9 +40,7 @@ "input", [ ("float8_e4m3fn", "__nv_fp8_e4m3"), - ("float8_e4m3fnuz", "__nv_fp8_e4m3"), ("float8_e5m2", "__nv_fp8_e5m2"), - ("float8_e5m2fnuz", "__nv_fp8_e5m2"), ], ) @tvm.testing.requires_cuda_compute_version(10) @@ -90,7 +88,7 @@ def add( @pytest.mark.parametrize( "dtype", - ["float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz", "float8_e8m0fnu"], + ["float8_e4m3fn", "float8_e5m2", "float8_e8m0fnu"], ) @tvm.testing.requires_cuda_compute_version(10) def test_fp8_packing(dtype): From 481b6c8517063f62bf6aa9cbb738ab279f81513e Mon Sep 17 00:00:00 2001 From: Kathryn-cat Date: Mon, 2 Jun 2025 16:54:24 -0400 Subject: [PATCH 3/4] fix ci error --- python/tvm/runtime/ndarray.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 78af2569d2e0..425327d36b97 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -213,7 +213,8 @@ def numpy(self): np_arr = np.empty(shape, dtype=dtype) assert np_arr.flags["C_CONTIGUOUS"] data = np_arr.ctypes.data_as(ctypes.c_void_p) - nbytes = (np_arr.size * old_dtype.bits + 7) // 8 + # TODO(kathy): revisit and get a mirrored function of ffi::GetDataSize in Python to replace line below + nbytes = np_arr.size if dtype == "bool" else (np_arr.size * old_dtype.bits + 7) // 8 _ffi_api.TVMArrayCopyToBytes(self, data, nbytes) if old_dtype == "int4" or old_dtype.startswith("float4_e2m1fn"): From 87ecbaacec767fba8a3475f5625317b10254bc3b Mon Sep 17 00:00:00 2001 From: Kathryn-cat Date: Mon, 2 Jun 2025 17:40:52 -0400 Subject: [PATCH 4/4] lint --- python/tvm/runtime/ndarray.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 425327d36b97..9d49d9c51db2 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -213,7 +213,8 @@ def numpy(self): np_arr = np.empty(shape, dtype=dtype) assert np_arr.flags["C_CONTIGUOUS"] data = np_arr.ctypes.data_as(ctypes.c_void_p) - # TODO(kathy): revisit and get a mirrored function of ffi::GetDataSize in Python to replace line below + # TODO(kathy): revisit and get a mirrored function of ffi::GetDataSize + # in Python to replace line below nbytes = np_arr.size if dtype == "bool" else (np_arr.size * old_dtype.bits + 7) // 8 _ffi_api.TVMArrayCopyToBytes(self, data, nbytes)
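
A note on the TODO above: the nbytes expression special-cases bool as one byte per
element and otherwise rounds the packed bit count up to whole bytes. A Python mirror
of ffi::GetDataSize would presumably also fold in lanes, which matches the itemsize
expectations in tests/python/ffi/test_dtype.py (float6_e2m3fnx4 -> 3 bytes,
float4_e2m1fnx4 -> 2 bytes). A minimal sketch follows; the helper name _data_nbytes
is hypothetical and not part of this patch series, and it only restates that rule.

    from tvm import ffi as tvm_ffi

    def _data_nbytes(numel, dtype_str):
        """Hypothetical sketch: bytes needed to hold numel elements of dtype_str."""
        dtype = tvm_ffi.dtype(dtype_str)
        if dtype_str == "bool":
            # bool is stored one byte per element, matching the special case above
            return numel
        # sub-byte dtypes (int4, float4_e2m1fn, float6_e2m3fn, ...) pack bits
        # across elements and round the total up to whole bytes
        return (numel * dtype.bits * dtype.lanes + 7) // 8

    # consistent with the itemsize values asserted in test_dtype_itemsize
    assert _data_nbytes(4, "float4_e2m1fn") == 2
    assert _data_nbytes(1, "float6_e2m3fnx4") == 3
    assert _data_nbytes(16, "bool") == 16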