diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 349b7d2d546f..d001b671fc57 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -233,7 +233,12 @@ def numpy(self):
         if dtype == "int4":
             dtype = "int8"
         if dtype == "bfloat16":
-            dtype = "uint16"
+            if ml_dtypes is not None:
+                dtype = ml_dtypes.bfloat16
+            else:
+                raise RuntimeError(
+                    "ml_dtypes is not installed, cannot convert bfloat16 array to numpy."
+                )
         if dtype == "float8_e4m3fn":
             if ml_dtypes is not None:
                 dtype = ml_dtypes.float8_e4m3fn
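Note: after this change, `NDArray.numpy()` hands back a genuine `ml_dtypes.bfloat16` array instead of raw `uint16` bit patterns, and fails loudly when `ml_dtypes` is absent. A minimal round-trip sketch of the new behavior, assuming `ml_dtypes` is installed and that `tvm.nd.array` accepts `ml_dtypes` inputs on this build:

```python
import ml_dtypes
import numpy as np
import tvm

x = np.arange(8, dtype=np.float32).astype(ml_dtypes.bfloat16)
arr = tvm.nd.array(x)                  # NDArray keeps the "bfloat16" dtype
y = arr.numpy()                        # previously: uint16 bit patterns
assert y.dtype == ml_dtypes.bfloat16   # now: a real bfloat16 array
np.testing.assert_array_equal(x, y)
```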
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 34023e0bb7d7..a97e66d3467c 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -194,7 +194,7 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#include <cuda_fp4.h>\n";
     decl_stream << "#endif\n\n";
   }
-  declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_, enable_fp4_);
+  declare_vector_type_extensions(decl_stream, enable_fp16_, enable_bf16_, enable_fp8_, enable_fp4_);
 
   if (enable_warp_shuffle_) {
     decl_stream << _cuda_warp_intrinsic_util;
@@ -331,8 +331,12 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
       if (t.is_scalar()) {
         os << "nv_bfloat16";
       } else if (lanes <= 8) {
-        ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
-        os << "uint" << lanes / 2;
+        ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type";
+        if (lanes <= 4) {
+          os << "nv_bfloat16" << lanes;
+        } else {
+          os << "uint" << lanes / 2;
+        }
       } else {
         fail = true;
       }
@@ -575,7 +579,11 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
       os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
     }
   } else if (t.is_bfloat16()) {
-    os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
+    if (t.lanes() <= 4) {
+      os << vec << "." << access[i];
+    } else {
+      os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
+    }
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
@@ -630,8 +638,12 @@
      }
    }
   } else if (t.is_bfloat16()) {
-    stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
-           << " = " << value << ";\n";
+    if (t.lanes() <= 4) {
+      stream << vec << "." << access[i] << " = " << value << ";\n";
+    } else {
+      stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
+             << " = " << value << ";\n";
+    }
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
@@ -736,9 +748,13 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
       target_ty.code() == DataType::kFloat4_e2m1fn || from_ty.code() == DataType::kFloat8_e4m3fn ||
       from_ty.code() == DataType::kFloat8_e5m2 || from_ty.code() == DataType::kFloat4_e2m1fn) {
     std::ostringstream val;
-    val << "(";
-    PrintType(target_ty, val);
-    val << ")(" << PrintExpr(op->value) << ")";
+    if (target_ty.code() == DataType::kBFloat && target_ty.lanes() == 2) {
+      val << "cast_to_nv_bfloat162(" << PrintExpr(op->value) << ")";
+    } else {
+      val << "(";
+      PrintType(target_ty, val);
+      val << ")(" << PrintExpr(op->value) << ")";
+    }
     os << val.str();
     return;
   }
@@ -1384,9 +1400,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NOLINT(*)
     std::string v = PrintExpr(op->value);
     PrintVecConstructor(op->dtype, os);
     os << '(';
-    for (int i = 0; i < lanes / 2; ++i) {
-      if (i != 0) os << ", ";
-      os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
+    if (lanes > 4) {
+      for (int i = 0; i < lanes / 2; ++i) {
+        if (i != 0) os << ", ";
+        os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
+      }
+    } else {
+      for (int i = 0; i < lanes; ++i) {
+        if (i != 0) os << ", ";
+        os << v;
+      }
     }
     os << ')';
     return;
@@ -1660,15 +1683,10 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
       PrintVecConstructor(t, os);
       os << '(';
     }
-    if (i % 2 == 0) {
-      os << "__pack_nv_bfloat162(" << value;
+    if (i == t.lanes() - 1) {
+      os << value << ")";
     } else {
-      os << "," << value << ")";
-      if (i != t.lanes() - 1) {
-        os << ",";
-      } else {
-        os << ")";
-      }
+      os << value << ",";
     }
     return;
   }
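Note: the codegen changes above give bfloat16 vectors of up to 4 lanes native `nv_bfloat162`/`nv_bfloat164` types with plain member access (`v.x`, `v.y`, ...), while longer vectors keep the packed-`uint` representation. A sketch of how one could observe this in the generated source; the kernel is hypothetical and the snippet assumes a CUDA toolchain is available:

```python
import tvm
from tvm.script import tir as T

@T.prim_func
def copy2(A: T.Buffer((64,), "bfloat16"), B: T.Buffer((64,), "bfloat16")):
    for i in T.thread_binding(32, thread="threadIdx.x"):
        for j in T.vectorized(2):
            B[i * 2 + j] = A[i * 2 + j]

mod = tvm.tir.build(copy2, target="cuda")
src = mod.imported_modules[0].get_source()
# lanes <= 4 now hits the native-type branch of PrintType above
assert "nv_bfloat162" in src
```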
diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h
index 86f2219fe8cb..b095f5b8cf20 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -385,52 +385,70 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"(
 )";
 
-void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_fp8,
-                                    bool enable_fp4) {
-  if (enable_fp16 || enable_fp8 || enable_fp4) {
+void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_bf16,
+                                    bool enable_fp8, bool enable_fp4) {
+  if (enable_fp16 || enable_bf16) {
     stream << R"(
-struct __align__(8) half4 {
-  __half x, y, z, w;
-  __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}
-  __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}
+#include <type_traits>
+template <typename T, typename TVec2>
+struct __align__(8) half4_bfloat164 {
+  T x, y, z, w;
+  __host__ __device__ half4_bfloat164() : x(T(0)), y(T(0)), z(T(0)), w(T(0)) {}
+  __host__ __device__ half4_bfloat164(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {}
 )";
     if (enable_fp8) {
       stream << R"(
-  __host__ __device__ explicit half4(const __nv_fp8x4_e4m3& fp8x4) {
-    __nv_fp8x2_e4m3 lo_part, hi_part;
-    lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
-    hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
-    __half2 lo_half2 = static_cast<__half2>(lo_part);
-    __half2 hi_half2 = static_cast<__half2>(hi_part);
-    x = reinterpret_cast<__half*>(&lo_half2)[0];
-    y = reinterpret_cast<__half*>(&lo_half2)[1];
-    z = reinterpret_cast<__half*>(&hi_half2)[0];
-    w = reinterpret_cast<__half*>(&hi_half2)[1];
+  __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e4m3& fp8x4) {
+    if constexpr (std::is_same_v<T, __half>) {
+      __nv_fp8x2_e4m3 lo_part, hi_part;
+      lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
+      hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
+      TVec2 lo_half2 = static_cast<TVec2>(lo_part);
+      TVec2 hi_half2 = static_cast<TVec2>(hi_part);
+      x = reinterpret_cast<T*>(&lo_half2)[0];
+      y = reinterpret_cast<T*>(&lo_half2)[1];
+      z = reinterpret_cast<T*>(&hi_half2)[0];
+      w = reinterpret_cast<T*>(&hi_half2)[1];
+    } else {
+      __nv_fp8_storage_t elem0_raw = static_cast<__nv_fp8_storage_t>(fp8x4.__x & 0xFF);
+      __nv_fp8_storage_t elem1_raw = static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 8) & 0xFF);
+      __nv_fp8_storage_t elem2_raw = static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 16) & 0xFF);
+      __nv_fp8_storage_t elem3_raw = static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 24) & 0xFF);
+      __nv_fp8_e4m3 elem0, elem1, elem2, elem3;
+      elem0.__x = elem0_raw;
+      elem1.__x = elem1_raw;
+      elem2.__x = elem2_raw;
+      elem3.__x = elem3_raw;
+      x = T(elem0);
+      y = T(elem1);
+      z = T(elem2);
+      w = T(elem3);
+    }
   }
   __host__ __device__ explicit operator __nv_fp8x4_e4m3() const {
     __nv_fp8x4_e4m3 result;
-    __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
-    __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+    TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
+    TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
     __nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2);
     result.__x =
         (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
     return result;
   }
-  __host__ __device__ explicit half4(const __nv_fp8x4_e5m2& fp8x4) {
-    __nv_fp8x2_e5m2 lo_part, hi_part;
-    lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
-    hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
-    __half2 lo_half2 = static_cast<__half2>(lo_part);
-    __half2 hi_half2 = static_cast<__half2>(hi_part);
-    x = reinterpret_cast<__half*>(&lo_half2)[0];
-    y = reinterpret_cast<__half*>(&lo_half2)[1];
-    z = reinterpret_cast<__half*>(&hi_half2)[0];
-    w = reinterpret_cast<__half*>(&hi_half2)[1];
+  __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e5m2& fp8x4) {
+    __nv_fp8x2_e5m2 lo_part, hi_part;
+    lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
+    hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
+    TVec2 lo_half2 = static_cast<TVec2>(lo_part);
+    TVec2 hi_half2 = static_cast<TVec2>(hi_part);
+    x = reinterpret_cast<T*>(&lo_half2)[0];
+    y = reinterpret_cast<T*>(&lo_half2)[1];
+    z = reinterpret_cast<T*>(&hi_half2)[0];
+    w = reinterpret_cast<T*>(&hi_half2)[1];
   }
   __host__ __device__ explicit operator __nv_fp8x4_e5m2() const {
     __nv_fp8x4_e5m2 result;
-    __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
-    __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+    TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
+    TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
     __nv_fp8x2_e5m2 lo_part(lo_half2), hi_part(hi_half2);
     result.__x =
         (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
@@ -460,31 +478,70 @@ struct __align__(8) half4 {
   }
   if (enable_fp4) {
     stream << R"(
-  __host__ __device__ explicit half4(const __nv_fp4x4_e2m1& fp4x4) {
-    __nv_fp4x2_storage_t lo_part, hi_part;
-    lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
-    hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
-    __half2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
-    __half2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1));
-    x = reinterpret_cast<__half*>(&lo_half2)[0];
-    y = reinterpret_cast<__half*>(&lo_half2)[1];
-    z = reinterpret_cast<__half*>(&hi_half2)[0];
-    w = reinterpret_cast<__half*>(&hi_half2)[1];
+  __host__ __device__ explicit half4_bfloat164(const __nv_fp4x4_e2m1& fp4x4) {
+    if constexpr (std::is_same_v<T, __half>) {
+      __nv_fp4x2_storage_t lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
+      __nv_fp4x2_storage_t hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
+      TVec2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
+      TVec2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1));
+      x = reinterpret_cast<T*>(&lo_half2)[0];
+      y = reinterpret_cast<T*>(&lo_half2)[1];
+      z = reinterpret_cast<T*>(&hi_half2)[0];
+      w = reinterpret_cast<T*>(&hi_half2)[1];
+    } else {
+      __nv_fp4_e2m1 elem0, elem1, elem2, elem3;
+      elem0.__x = static_cast<__nv_fp4_storage_t>(fp4x4.__x & 0xF);
+      elem1.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 4) & 0xF);
+      elem2.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 8) & 0xF);
+      elem3.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 12) & 0xF);
+      x = T(elem0);
+      y = T(elem1);
+      z = T(elem2);
+      w = T(elem3);
+    }
   }
   __host__ __device__ explicit operator __nv_fp4x4_e2m1() const {
-    __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
-    __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+    TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
+    TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
     return __nv_fp4x4_e2m1(lo_half2, hi_half2);
   }
 )";
   }
   stream << R"(
 };
+)";
+  }
+  if (enable_fp16) {
+    stream << R"(
+using half4 = half4_bfloat164<__half, __half2>;
 __host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
   return half4(x, y, z, w);
 }
 )";
   }
+  if (enable_bf16) {
+    stream << R"(
+using nv_bfloat164 = half4_bfloat164<nv_bfloat16, nv_bfloat162>;
+__host__ __device__ nv_bfloat164 make_nv_bfloat164(nv_bfloat16 x, nv_bfloat16 y, nv_bfloat16 z, nv_bfloat16 w) {
+  return nv_bfloat164(x, y, z, w);
+}
+__host__ __device__ nv_bfloat162 make_nv_bfloat162(nv_bfloat16 x, nv_bfloat16 y) {
+  return nv_bfloat162(x, y);
+}
+)";
+    if (enable_fp8) {
+      stream << R"(
+__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e4m3& fp8x2) {
+  __nv_fp8_e4m3 elem0, elem1;
+  elem0.__x = static_cast<__nv_fp8_storage_t>(fp8x2.__x & 0xFF);
+  elem1.__x = static_cast<__nv_fp8_storage_t>((fp8x2.__x >> 8) & 0xFF);
+  nv_bfloat16 x = nv_bfloat16(elem0);
+  nv_bfloat16 y = nv_bfloat16(elem1);
+  return nv_bfloat162(x, y);
+}
+)";
+    }
+  }
   if (enable_fp4) {
     stream << R"(
 __device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1 y) {
@@ -497,6 +554,14 @@ __device__ __nv_fp4x4_e2m1 make___nv_fp4x4_e2m1(__nv_fp4_e2m1 a, __nv_fp4_e2m1 b
   result.__x = (static_cast<__nv_fp4x4_storage_t>(a.__x)) | (static_cast<__nv_fp4x4_storage_t>(b.__x) << 4) |
                (static_cast<__nv_fp4x4_storage_t>(c.__x) << 8) | (static_cast<__nv_fp4x4_storage_t>(d.__x) << 12);
   return result;
 }
+__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp4x2_e2m1& fp4x2) {
+  __nv_fp4_e2m1 elem0, elem1;
+  elem0.__x = static_cast<__nv_fp4_storage_t>(fp4x2.__x & 0xF);
+  elem1.__x = static_cast<__nv_fp4_storage_t>((fp4x2.__x >> 4) & 0xF);
+  nv_bfloat16 x = nv_bfloat16(elem0);
+  nv_bfloat16 y = nv_bfloat16(elem1);
+  return nv_bfloat162(x, y);
+}
 )";
   }
 }
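Note: the `else` branch of the fp8 constructor exists because CUDA offers a paired `__nv_fp8x2 -> __half2` conversion but no `nv_bfloat162` counterpart, so the bfloat16 instantiation unpacks the four e4m3 bytes and widens them one at a time. A host-side sketch of the intended numerics, using `ml_dtypes` for the scalar conversions (reference only, not the emitted device code):

```python
import ml_dtypes
import numpy as np

def fp8x4_to_bf16x4(word: int) -> np.ndarray:
    """Unpack four e4m3 bytes from a 32-bit word and widen each to bfloat16."""
    raw = np.array([(word >> s) & 0xFF for s in (0, 8, 16, 24)], dtype=np.uint8)
    e4m3 = raw.view(ml_dtypes.float8_e4m3fn)  # reinterpret the raw bytes as fp8
    return e4m3.astype(ml_dtypes.bfloat16)    # widen element-wise (exact)

print(fp8x4_to_bf16x4(0x38383838))  # 0x38 encodes 1.0 in e4m3 -> [1. 1. 1. 1.]
```

The widening is exact: bfloat16 has a wider exponent range and more mantissa bits than e4m3, so only the packing/unpacking order needs care.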
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index 5e0a4c3000da..7b3a20463ba9 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import sys
+from itertools import product
 from typing import List, Tuple
 
 import numpy as np
@@ -26,13 +26,9 @@
 from tvm import DataType, DataTypeCode, IRModule
 from tvm import dlight as dl
 from tvm import relax, te, tir, topi
-from tvm.relax.frontend import nn
-from tvm.runtime import NDArray
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
-from tvm.target import Target
-from tvm.topi.utils import get_const_tuple
 
 try:
     import ml_dtypes
@@ -67,7 +63,7 @@ def add(
     sch.bind(tx, "threadIdx.x")
 
     target = "cuda"
-    fadd = tvm.compile(sch.mod, target=target)
+    fadd = tvm.tir.build(sch.mod, target=target)
     cuda_src = fadd.imported_modules[0].get_source()
     assert "__nv_fp8_e4m3" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in generated CUDA"
 
@@ -179,7 +175,7 @@ def add(
     sch.bind(tx, "threadIdx.x")
 
     target = "cuda"
-    fadd = tvm.compile(sch.mod, target=target)
+    fadd = tvm.tir.build(sch.mod, target=target)
     cuda_src = fadd.imported_modules[0].get_source()
 
     dev = tvm.device(target, 0)
@@ -700,7 +696,7 @@ def compile_quant_and_dequant_by_scale(
     def print_cuda(target, mod, name=None):
         if name:
             mod = mod[name]
-        f = tvm.compile(mod, target=target)
+        f = tvm.tir.build(mod, target=target)
         cuda_src = f.imported_modules[0].get_source()
         print(cuda_src)
 
@@ -963,6 +959,41 @@ def _pipeline(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
     vm["main"](x, indptr, weight, scale)
 
 
+@tvm.testing.requires_cuda_compute_version(8, 9)
+def test_fp8_fp16_bf16_vectorize_arith():
+    for vec_length, dtype in product([2, 4], ["float16", "bfloat16"]):
+
+        @T.prim_func
+        def func_vectorize(
+            A: T.Buffer((128,), "float8_e4m3fn"),
+            B: T.Buffer((128,), dtype),
+            C: T.Buffer((128,), dtype),
+        ) -> None:
+            for i in T.serial(128):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    C[vi] = (A[vi].astype(dtype) * B[vi]) + T.FloatImm(dtype, 3.0)
+
+        sch = tir.Schedule(func_vectorize)
+        (l,) = sch.get_loops(sch.get_block("compute"))
+        lo, li = sch.split(l, [None, vec_length])
+        sch.bind(lo, "threadIdx.x")
+        sch.vectorize(li)
+
+        device = tvm.cuda()
+        target = tvm.target.Target.from_device(device)
+        f = tir.build(sch.mod, target=target)
+
+        a_np = np.random.rand(128).astype("float8_e4m3fn")
+        b_np = np.random.rand(128).astype(dtype)
+        c_np = (a_np.astype(dtype) * b_np) + 3
+        a_tvm = tvm.nd.array(a_np, device=device)
+        b_tvm = tvm.nd.array(b_np, device=device)
+        c_tvm = tvm.nd.empty((128,), dtype=dtype, device=device)
+        f(a_tvm, b_tvm, c_tvm)
+        c_tvm = c_tvm.numpy()
+        np.testing.assert_allclose(c_tvm, c_np, atol=1e-3, rtol=1e-3)
+
+
 if __name__ == "__main__":
-    # test_half_broadcast(6)
     tvm.testing.main()
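Note: the added constant is written as `T.FloatImm(dtype, 3.0)` so the literal tracks the loop's target dtype; a fixed `T.bfloat16(3.0)` would mix bfloat16 into the float16 iteration. The test's host reference can be reproduced standalone with `ml_dtypes` (a sketch; the fp8 operand is widened to the target dtype before the arithmetic, so the tolerances only need to absorb fp8 input quantization):

```python
import ml_dtypes
import numpy as np

a = np.random.rand(128).astype(ml_dtypes.float8_e4m3fn)  # quantized input
b = np.random.rand(128).astype(ml_dtypes.bfloat16)
c = a.astype(ml_dtypes.bfloat16) * b + ml_dtypes.bfloat16(3.0)
assert c.dtype == ml_dtypes.bfloat16  # arithmetic stays in the target dtype
```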
diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py
index e4a28453184d..cea0bc2d9318 100644
--- a/tests/python/codegen/test_target_codegen_llvm.py
+++ b/tests/python/codegen/test_target_codegen_llvm.py
@@ -758,33 +758,6 @@ def check_llvm_ir():
     check_llvm_ir()
 
 
-def np_float2np_bf16(arr):
-    """Convert a numpy array of float to a numpy array
-    of bf16 in uint16"""
-    orig = arr.view("<u4")
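Note: the removed `np_float2np_bf16` helper emulated the float32 -> bf16 round-to-nearest-even conversion with integer bit tricks; with `ml_dtypes` the same conversion is a plain `astype`. A quick equivalence sketch between a hand-rolled conversion of that style and the `ml_dtypes` path (finite inputs only; NaN payloads may differ):

```python
import ml_dtypes
import numpy as np

def float2bf16_bits(arr: np.ndarray) -> np.ndarray:
    """float32 -> bf16 bit patterns in uint16, via ml_dtypes."""
    return arr.astype(ml_dtypes.bfloat16).view(np.uint16)

x = np.random.uniform(-10, 10, 256).astype("float32")
orig = x.view("<u4")
bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF  # round to nearest even
manual = np.right_shift(orig + bias, 16).astype("<u2")
np.testing.assert_array_equal(manual, float2bf16_bits(x))
```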