7 changes: 6 additions & 1 deletion python/tvm/runtime/ndarray.py
@@ -233,7 +233,12 @@ def numpy(self):
if dtype == "int4":
dtype = "int8"
if dtype == "bfloat16":
dtype = "uint16"
if ml_dtypes is not None:
dtype = ml_dtypes.bfloat16
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert bfloat16 array to numpy."
)
if dtype == "float8_e4m3fn":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e4m3fn
58 changes: 38 additions & 20 deletions src/target/source/codegen_cuda.cc
@@ -194,7 +194,7 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <cuda_fp4.h>\n";
decl_stream << "#endif\n\n";
}
declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_, enable_fp4_);
declare_vector_type_extensions(decl_stream, enable_fp16_, enable_bf16_, enable_fp8_, enable_fp4_);

if (enable_warp_shuffle_) {
decl_stream << _cuda_warp_intrinsic_util;
@@ -331,8 +331,12 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
if (t.is_scalar()) {
os << "nv_bfloat16";
} else if (lanes <= 8) {
ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "uint" << lanes / 2;
ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type";
if (lanes <= 4) {
os << "nv_bfloat16" << lanes;
} else {
os << "uint" << lanes / 2;
}
} else {
fail = true;
}
@@ -575,7 +579,11 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
}
} else if (t.is_bfloat16()) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
if (t.lanes() <= 4) {
os << vec << "." << access[i];
} else {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
}
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
@@ -630,8 +638,12 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
}

} else if (t.is_bfloat16()) {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
<< " = " << value << ";\n";
if (t.lanes() <= 4) {
stream << vec << "." << access[i] << " = " << value << ";\n";
} else {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
<< " = " << value << ";\n";
}
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
@@ -736,9 +748,13 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
target_ty.code() == DataType::kFloat4_e2m1fn || from_ty.code() == DataType::kFloat8_e4m3fn ||
from_ty.code() == DataType::kFloat8_e5m2 || from_ty.code() == DataType::kFloat4_e2m1fn) {
std::ostringstream val;
val << "(";
PrintType(target_ty, val);
val << ")(" << PrintExpr(op->value) << ")";
if (target_ty.code() == DataType::kBFloat && target_ty.lanes() == 2) {
val << "cast_to_nv_bfloat162(" << PrintExpr(op->value) << ")";
} else {
val << "(";
PrintType(target_ty, val);
val << ")(" << PrintExpr(op->value) << ")";
}
os << val.str();
return;
}
@@ -1384,9 +1400,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
std::string v = PrintExpr(op->value);
PrintVecConstructor(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
if (lanes > 4) {
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
}
} else {
for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
}
os << ')';
return;
@@ -1660,15 +1683,10 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
PrintVecConstructor(t, os);
os << '(';
}
if (i % 2 == 0) {
os << "__pack_nv_bfloat162(" << value;
if (i == t.lanes() - 1) {
os << value << ")";
} else {
os << "," << value << ")";
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
os << value << ",";
}
return;
}
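
Note: below is a minimal hand-written CUDA sketch of the style of device code the updated codegen emits for a 4-lane bfloat16 buffer. It is not TVM output; the kernel name, signature, and the float round-trip used for the multiply are illustrative assumptions. It also assumes the vector-type prelude emitted by declare_vector_type_extensions (see cuda_half_t.h below) is prepended, so nv_bfloat164 and make_nv_bfloat164 are in scope.

#include <cuda_bf16.h>

extern "C" __global__ void scale_bf16x4(nv_bfloat164* __restrict__ buf, nv_bfloat16 s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  // Broadcast for lanes <= 4 now repeats the scalar per lane instead of packing
  // pairs with __pack_nv_bfloat162 (BroadcastNode path above).
  nv_bfloat164 sv = make_nv_bfloat164(s, s, s, s);
  nv_bfloat164 v = buf[i];
  // Per-lane access for lanes <= 4 is plain member access (PrintVecElemLoad/Store above),
  // not a reinterpret through nv_bfloat162*. The arithmetic goes through float here only
  // to stay architecture-neutral in this sketch.
  v.x = __float2bfloat16(__bfloat162float(v.x) * __bfloat162float(sv.x));
  v.y = __float2bfloat16(__bfloat162float(v.y) * __bfloat162float(sv.y));
  v.z = __float2bfloat16(__bfloat162float(v.z) * __bfloat162float(sv.z));
  v.w = __float2bfloat16(__bfloat162float(v.w) * __bfloat162float(sv.w));
  buf[i] = v;
}
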
151 changes: 108 additions & 43 deletions src/target/source/literal/cuda_half_t.h
@@ -385,52 +385,70 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"(

)";

void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_fp8,
bool enable_fp4) {
if (enable_fp16 || enable_fp8 || enable_fp4) {
void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_bf16,
bool enable_fp8, bool enable_fp4) {
if (enable_fp16 || enable_bf16) {
stream << R"(
struct __align__(8) half4 {
__half x, y, z, w;
__host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}
__host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}
#include <type_traits>
template <typename T, typename TVec2>
struct __align__(8) half4_bfloat164 {
T x, y, z, w;
__host__ __device__ half4_bfloat164() : x(T(0)), y(T(0)), z(T(0)), w(T(0)) {}
__host__ __device__ half4_bfloat164(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {}
)";
if (enable_fp8) {
stream << R"(
__host__ __device__ explicit half4(const __nv_fp8x4_e4m3& fp8x4) {
__nv_fp8x2_e4m3 lo_part, hi_part;
lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
__half2 lo_half2 = static_cast<__half2>(lo_part);
__half2 hi_half2 = static_cast<__half2>(hi_part);
x = reinterpret_cast<__half*>(&lo_half2)[0];
y = reinterpret_cast<__half*>(&lo_half2)[1];
z = reinterpret_cast<__half*>(&hi_half2)[0];
w = reinterpret_cast<__half*>(&hi_half2)[1];
__host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e4m3& fp8x4) {
if constexpr (std::is_same_v<T, __half>) {
__nv_fp8x2_e4m3 lo_part, hi_part;
lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
TVec2 lo_half2 = static_cast<TVec2>(lo_part);
TVec2 hi_half2 = static_cast<TVec2>(hi_part);
x = reinterpret_cast<T*>(&lo_half2)[0];
y = reinterpret_cast<T*>(&lo_half2)[1];
z = reinterpret_cast<T*>(&hi_half2)[0];
w = reinterpret_cast<T*>(&hi_half2)[1];
} else {
__nv_fp8_storage_t elem0_raw = static_cast<__nv_fp8_storage_t>(fp8x4.__x & 0xFF);
__nv_fp8_storage_t elem1_raw = static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 8) & 0xFF);
__nv_fp8_storage_t elem2_raw = static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 16) & 0xFF);
__nv_fp8_storage_t elem3_raw = static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 24) & 0xFF);
__nv_fp8_e4m3 elem0, elem1, elem2, elem3;
elem0.__x = elem0_raw;
elem1.__x = elem1_raw;
elem2.__x = elem2_raw;
elem3.__x = elem3_raw;
x = T(elem0);
y = T(elem1);
z = T(elem2);
w = T(elem3);
}
}
__host__ __device__ explicit operator __nv_fp8x4_e4m3() const {
__nv_fp8x4_e4m3 result;
__half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
__half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
__nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2);
result.__x =
(static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
return result;
}
__host__ __device__ explicit half4(const __nv_fp8x4_e5m2& fp8x4) {
__nv_fp8x2_e5m2 lo_part, hi_part;
lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
__half2 lo_half2 = static_cast<__half2>(lo_part);
__half2 hi_half2 = static_cast<__half2>(hi_part);
x = reinterpret_cast<__half*>(&lo_half2)[0];
y = reinterpret_cast<__half*>(&lo_half2)[1];
z = reinterpret_cast<__half*>(&hi_half2)[0];
w = reinterpret_cast<__half*>(&hi_half2)[1];
__host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e5m2& fp8x4) {
__nv_fp8x2_e5m2 lo_part, hi_part;
lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
TVec2 lo_half2 = static_cast<TVec2>(lo_part);
TVec2 hi_half2 = static_cast<TVec2>(hi_part);
x = reinterpret_cast<T*>(&lo_half2)[0];
y = reinterpret_cast<T*>(&lo_half2)[1];
z = reinterpret_cast<T*>(&hi_half2)[0];
w = reinterpret_cast<T*>(&hi_half2)[1];
}
__host__ __device__ explicit operator __nv_fp8x4_e5m2() const {
__nv_fp8x4_e5m2 result;
__half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
__half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
__nv_fp8x2_e5m2 lo_part(lo_half2), hi_part(hi_half2);
result.__x =
(static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
@@ -460,31 +478,70 @@ struct __align__(8) half4 {
}
if (enable_fp4) {
stream << R"(
__host__ __device__ explicit half4(const __nv_fp4x4_e2m1& fp4x4) {
__nv_fp4x2_storage_t lo_part, hi_part;
lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
__half2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
__half2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1));
x = reinterpret_cast<__half*>(&lo_half2)[0];
y = reinterpret_cast<__half*>(&lo_half2)[1];
z = reinterpret_cast<__half*>(&hi_half2)[0];
w = reinterpret_cast<__half*>(&hi_half2)[1];
__host__ __device__ explicit half4_bfloat164(const __nv_fp4x4_e2m1& fp4x4) {
if constexpr (std::is_same_v<T, __half>) {
__nv_fp4x2_storage_t lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
__nv_fp4x2_storage_t hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
TVec2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
TVec2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1));
x = reinterpret_cast<T*>(&lo_half2)[0];
y = reinterpret_cast<T*>(&lo_half2)[1];
z = reinterpret_cast<T*>(&hi_half2)[0];
w = reinterpret_cast<T*>(&hi_half2)[1];
} else {
__nv_fp4_e2m1 elem0, elem1, elem2, elem3;
elem0.__x = static_cast<__nv_fp4_storage_t>(fp4x4.__x & 0xF);
elem1.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 4) & 0xF);
elem2.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 8) & 0xF);
elem3.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 12) & 0xF);
x = T(elem0);
y = T(elem1);
z = T(elem2);
w = T(elem3);
}
}
__host__ __device__ explicit operator __nv_fp4x4_e2m1() const {
__half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
__half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
return __nv_fp4x4_e2m1(lo_half2, hi_half2);
}
)";
}
stream << R"(
};
)";
}
if (enable_fp16) {
stream << R"(
using half4 = half4_bfloat164<__half, __half2>;
__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
return half4(x, y, z, w);
}
)";
}
if (enable_bf16) {
stream << R"(
using nv_bfloat164 = half4_bfloat164<nv_bfloat16, nv_bfloat162>;
__host__ __device__ nv_bfloat164 make_nv_bfloat164(nv_bfloat16 x, nv_bfloat16 y, nv_bfloat16 z, nv_bfloat16 w) {
return nv_bfloat164(x, y, z, w);
}
__host__ __device__ nv_bfloat162 make_nv_bfloat162(nv_bfloat16 x, nv_bfloat16 y) {
return nv_bfloat162(x, y);
}
)";
if (enable_fp8) {
stream << R"(
__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e4m3& fp8x2) {
__nv_fp8_e4m3 elem0, elem1;
elem0.__x = static_cast<__nv_fp8_storage_t>(fp8x2.__x & 0xFF);
elem1.__x = static_cast<__nv_fp8_storage_t>((fp8x2.__x >> 8) & 0xFF);
nv_bfloat16 x = nv_bfloat16(elem0);
nv_bfloat16 y = nv_bfloat16(elem1);
return nv_bfloat162(x, y);
}
)";
}
}
if (enable_fp4) {
stream << R"(
__device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1 y) {
@@ -497,6 +554,14 @@ __device__ __nv_fp4x4_e2m1 make___nv_fp4x4_e2m1(__nv_fp4_e2m1 a, __nv_fp4_e2m1 b
result.__x = (static_cast<__nv_fp4x4_storage_t>(a.__x)) | (static_cast<__nv_fp4x4_storage_t>(b.__x) << 4) | (static_cast<__nv_fp4x4_storage_t>(c.__x) << 8) | (static_cast<__nv_fp4x4_storage_t>(d.__x) << 12);
return result;
}
__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp4x2_e2m1& fp4x2) {
__nv_fp4_e2m1 elem0, elem1;
elem0.__x = static_cast<__nv_fp4_storage_t>(fp4x2.__x & 0xF);
elem1.__x = static_cast<__nv_fp4_storage_t>((fp4x2.__x >> 4) & 0xF);
nv_bfloat16 x = nv_bfloat16(elem0);
nv_bfloat16 y = nv_bfloat16(elem1);
return nv_bfloat162(x, y);
}
)";
}
}
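
A short hand-written usage sketch of the new bfloat16 helpers emitted above. The kernel name and layout are illustrative assumptions, and the prelude produced by declare_vector_type_extensions is assumed to be prepended to the kernel source so that cast_to_nv_bfloat162 is in scope:

#include <cuda_bf16.h>
#include <cuda_fp8.h>

extern "C" __global__ void dequant_fp8x2_to_bf16x2(const __nv_fp8x2_e4m3* __restrict__ in,
                                                   nv_bfloat162* __restrict__ out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  // Each __nv_fp8x2_e4m3 packs two e4m3 lanes; the helper widens both lanes to
  // bfloat16 via the scalar __nv_fp8_e4m3 -> nv_bfloat16 conversion and repacks
  // them into an nv_bfloat162, matching what CastNode now emits for
  // float8_e4m3fnx2 -> bfloat16x2 casts.
  out[i] = cast_to_nv_bfloat162(in[i]);
}
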