16
16
17
17
// ================================================================================
18
18
// this file has been auto-generated, do not modify its contents!
19
- // date: 2025-08-21 10:13:04.148230
20
- // git hash: 4d0d49cad7962d3f9ba4f2a0abfa2faea3ec7efa
19
+ // date: 2025-09-02 18:31:16.281730
20
+ // git hash: 023bc75e8ec67145cdcb447c5fd9aa7d7f180cc6
21
21
// ================================================================================
22
22
23
23
#ifndef KERNEL_FLOAT_MACROS_H
59
59
#define KERNEL_FLOAT_FP16_AVAILABLE (1 )
60
60
#endif // KERNEL_FLOAT_FP16_AVAILABLE
61
61
62
+ #ifndef KERNEL_FLOAT_FP16_OPS_AVAILABLE
63
+ #define KERNEL_FLOAT_FP16_OPS_AVAILABLE ((KERNEL_FLOAT_IS_CUDA && __CUDA_ARCH__ >= 530 ) || KERNEL_FLOAT_IS_HIP)
64
+ #endif
65
+
62
66
#ifndef KERNEL_FLOAT_BF16_AVAILABLE
63
67
#define KERNEL_FLOAT_BF16_AVAILABLE (1 )
64
68
#endif // KERNEL_FLOAT_BF16_AVAILABLE
65
69
70
+ #ifndef KERNEL_FLOAT_BF16_OPS_AVAILABLE
71
+ #define KERNEL_FLOAT_BF16_OPS_AVAILABLE ((KERNEL_FLOAT_IS_CUDA && __CUDA_ARCH__ >= 800 ) || KERNEL_FLOAT_IS_HIP)
72
+ #endif
73
+
66
74
#ifndef KERNEL_FLOAT_FP8_AVAILABLE
67
75
#ifdef __CUDACC_VER_MAJOR__
68
76
#define KERNEL_FLOAT_FP8_AVAILABLE (__CUDACC_VER_MAJOR__ >= 12 )
@@ -4171,6 +4179,7 @@ struct allow_float_fallback<half_t> {
4171
4179
// Function-like macro with an empty expansion. No whitespace may precede '(' here:
// otherwise the preprocessor defines an object-like macro whose replacement text is
// the parenthesized list itself, and every invocation expands to invalid code.
#define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2)
4172
4180
#endif
4173
4181
4182
+ #if KERNEL_FLOAT_FP16_OPS_AVAILABLE
4174
4183
KERNEL_FLOAT_FP16_UNARY_FUN (sin, hsin, h2sin)
4175
4184
KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos)
4176
4185
@@ -4191,6 +4200,7 @@ KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil)
4191
4200
KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint)
4192
4201
KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc)
4193
4202
KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2)
4203
+ #endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
4194
4204
4195
4205
#if KERNEL_FLOAT_IS_DEVICE
4196
4206
#define KERNEL_FLOAT_FP16_BINARY_FUN (NAME, FUN1, FUN2 ) \
@@ -4217,10 +4227,11 @@ KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2)
4217
4227
#endif
4218
4228
4219
4229
// These are not available in HIP
4230
+ #if KERNEL_FLOAT_FP16_OPS_AVAILABLE
4220
4231
#if KERNEL_FLOAT_IS_CUDA
4221
4232
KERNEL_FLOAT_FP16_BINARY_FUN (min, __hmin, __hmin2)
4222
4233
KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2)
4223
- #endif
4234
+ #endif // KERNEL_FLOAT_IS_CUDA
4224
4235
4225
4236
KERNEL_FLOAT_FP16_BINARY_FUN (add, __hadd, __hadd2)
4226
4237
KERNEL_FLOAT_FP16_BINARY_FUN(subtract, __hsub, __hsub2)
@@ -4233,7 +4244,9 @@ KERNEL_FLOAT_FP16_BINARY_FUN(less, __hlt, __hlt2)
4233
4244
KERNEL_FLOAT_FP16_BINARY_FUN(less_equal, __hle, __hle2)
4234
4245
KERNEL_FLOAT_FP16_BINARY_FUN(greater, __hgt, __hgt2)
4235
4246
// Pair __hge with its own 2-suffixed counterpart __hge2, as the sibling comparisons do
// (__hle/__hle2, __hgt/__hgt2). The previous __hgt2 is the greater-than intrinsic,
// which does not match the `greater_equal` operation being registered here.
KERNEL_FLOAT_FP16_BINARY_FUN(greater_equal, __hge, __hge2)
4247
+ #endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
4236
4248
4249
+ #if KERNEL_FLOAT_FP16_OPS_AVAILABLE
4237
4250
#if KERNEL_FLOAT_IS_DEVICE
4238
4251
namespace ops {
4239
4252
template <>
@@ -4270,7 +4283,8 @@ struct apply_impl<accurate_policy, ops::fma<half_t>, 2, half_t, half_t, half_t,
4270
4283
4271
4284
KERNEL_FLOAT_FAST_F32_MAP (KERNEL_FLOAT_FAST_FP16_DISPATCH)
4272
4285
} // namespace detail
4273
- #endif
4286
+ #endif // KERNEL_FLOAT_IS_DEVICE
4287
+ #endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
4274
4288
4275
4289
#define KERNEL_FLOAT_FP16_CAST (T, TO_HALF, FROM_HALF ) \
4276
4290
namespace ops { \
@@ -4335,7 +4349,7 @@ KERNEL_FLOAT_VECTOR_ALIAS(half, half_t)
4335
4349
4336
4350
} // namespace kernel_float
4337
4351
4338
- #endif
4352
+ #endif // KERNEL_FLOAT_FP16_AVAILABLE
4339
4353
4340
4354
#endif // KERNEL_FLOAT_FP16_H
4341
4355
#ifndef KERNEL_FLOAT_BF16_H
@@ -4369,10 +4383,6 @@ using bfloat16_t = __hip_bfloat16;
4369
4383
using bfloat16x2_t = __hip_bfloat162;
4370
4384
#endif
4371
4385
4372
- #if KERNEL_FLOAT_IS_CUDA && __CUDA_ARCH__ >= 800
4373
- #define KERNEL_FLOAT_BF16_OPS_SUPPORTED 1
4374
- #endif
4375
-
4376
4386
template <>
4377
4387
struct preferred_vector_size <bfloat16_t > {
4378
4388
static constexpr size_t value = 2 ;
@@ -4420,7 +4430,7 @@ struct allow_float_fallback<bfloat16_t> {
4420
4430
}; \
4421
4431
}
4422
4432
4423
- #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4433
+ #if KERNEL_FLOAT_BF16_OPS_AVAILABLE
4424
4434
KERNEL_FLOAT_BF16_UNARY_FUN (sin, ::hsin, ::h2sin)
4425
4435
KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos)
4426
4436
@@ -4496,7 +4506,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2)
4496
4506
}; \
4497
4507
}
4498
4508
4499
- #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4509
+ #if KERNEL_FLOAT_BF16_OPS_AVAILABLE
4500
4510
KERNEL_FLOAT_BF16_BINARY_FUN (add, __hadd, __hadd2)
4501
4511
KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2)
4502
4512
KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2)
@@ -4512,7 +4522,7 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2)
4512
4522
// Pair __hge with its own 2-suffixed counterpart __hge2, as the sibling comparison
// above does (__hgt/__hgt2). The previous __hgt2 is the greater-than intrinsic,
// which does not match the `greater_equal` operation being registered here.
KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hge2)
4513
4523
#endif
4514
4524
4515
- #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4525
+ #if KERNEL_FLOAT_BF16_OPS_AVAILABLE
4516
4526
namespace ops {
4517
4527
template <>
4518
4528
struct fma <bfloat16_t > {
@@ -4583,7 +4593,7 @@ KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_BF16_DISPATCH)
4583
4593
KERNEL_FLOAT_BF16_CAST (float , __float2bfloat16(input), __bfloat162float(input))
4584
4594
KERNEL_FLOAT_BF16_CAST (double , __double2bfloat16(input), __bfloat162float(input))
4585
4595
4586
- #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4596
+ #if KERNEL_FLOAT_BF16_OPS_AVAILABLE
4587
4597
// clang-format off
4588
4598
// there are no official char casts. Instead, cast to int and then to char
4589
4599
KERNEL_FLOAT_BF16_CAST (char , __int2bfloat16_rn(input), (char )__bfloat162int_rz(input));
@@ -4637,7 +4647,7 @@ struct promote_type<half_t, bfloat16_t> {
4637
4647
} // namespace kernel_float
4638
4648
4639
4649
#endif // KERNEL_FLOAT_FP16_AVAILABLE
4640
- #endif
4650
+ #endif // KERNEL_FLOAT_BF16_AVAILABLE
4641
4651
4642
4652
#endif // KERNEL_FLOAT_BF16_H
4643
4653
#pragma once
@@ -4728,7 +4738,7 @@ KERNEL_FLOAT_DEFINE_POLY(asin_poly, 3, 0.05167, -0.2057, 1.57)
4728
4738
KERNEL_FLOAT_DEFINE_POLY (asin_poly, 4 , -0.02103 , 0.077 , -0.2129 , 1.57 )
4729
4739
KERNEL_FLOAT_DEFINE_POLY (asin_poly, 5 , 0.009796 , -0.03772 , 0.0857 , -0.2142 , 1.57 )
4730
4740
4731
- #if KERNEL_FLOAT_FP16_AVAILABLE
4741
+ #if KERNEL_FLOAT_FP16_OPS_AVAILABLE
4732
4742
KERNEL_FLOAT_DEVICE half2_t flipsign (half2_t input, half2_t sign) {
4733
4743
// Flip signbit of input when sign<0
4734
4744
uint32_t result;
@@ -4923,9 +4933,9 @@ KERNEL_FLOAT_DEVICE half2_t tanh(half2_t x) {
4923
4933
}
4924
4934
}
4925
4935
4926
- #endif // KERNEL_FLOAT_FP16_AVAILABLE
4936
+ #endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
4927
4937
4928
- #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4938
+ #if KERNEL_FLOAT_BF16_OPS_AVAILABLE
4929
4939
KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162 (bfloat16_t x) {
4930
4940
return {x, x};
4931
4941
}
@@ -5005,7 +5015,7 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
5005
5015
transmute<bfloat16_t >(uint16_t (transmute<uint32_t >(a))),
5006
5016
transmute<bfloat16_t >(uint16_t (transmute<uint32_t >(b)))};
5007
5017
}
5008
- #endif
5018
+ #endif // KERNEL_FLOAT_BF16_OPS_AVAILABLE
5009
5019
} // namespace approx
5010
5020
5011
5021
namespace detail {
@@ -5036,7 +5046,7 @@ struct apply_impl<approx_level_policy<Level>, F, 1, T, T> {
5036
5046
apply_impl<approx_level_policy<DEFAULT_LEVEL>, ops::FUN<T>, 2 , T, T> {}; \
5037
5047
}
5038
5048
5039
- #if KERNEL_FLOAT_FP16_AVAILABLE
5049
+ #if KERNEL_FLOAT_FP16_OPS_AVAILABLE
5040
5050
KERNEL_FLOAT_DEFINE_APPROX_IMPL (half_t , sin, 4 )
5041
5051
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , cos, 4 )
5042
5052
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , rsqrt, 1 )
@@ -5048,7 +5058,7 @@ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, asin, 2)
5048
5058
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , acos, 2 )
5049
5059
#endif
5050
5060
5051
- #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
5061
+ #if KERNEL_FLOAT_BF16_OPS_AVAILABLE
5052
5062
KERNEL_FLOAT_DEFINE_APPROX_IMPL (bfloat16_t , cos, 4 )
5053
5063
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t , sin, 4 )
5054
5064
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t , rcp, 1 )
0 commit comments