|
2 | 2 |
|
3 | 3 | #include <ATen/ATen.h>
|
4 | 4 | #include <ATen/cuda/CUDAContext.h>
|
5 |
| -#include <ATen/cuda/NumericLimits.cuh> |
6 | 5 |
|
7 | 6 | #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
8 | 7 | #include <cuda.h>
|
9 | 8 | #include <cuda_bf16.h>
|
10 | 9 | #include <cuda_runtime.h>
|
| 10 | +#include <ATen/test/cuda_lowfp_test.cuh> |
11 | 11 |
|
12 | 12 | #include <assert.h>
|
13 | 13 |
|
14 | 14 | using namespace at;
|
15 | 15 |
|
16 |
| -__device__ void test(){ |
17 |
| - // test bfloat16 construction and implicit conversions in device |
18 |
| - assert(BFloat16(3) == BFloat16(3.0f)); |
19 |
| - assert(static_cast<BFloat16>(3.0f) == BFloat16(3.0f)); |
20 |
| - // there is no float <=> __nv_bfloat16 implicit conversion |
21 |
| - assert(static_cast<BFloat16>(3.0f) == 3.0f); |
| 16 | +__global__ void kernel(){ |
| 17 | + test<BFloat16>(); |
22 | 18 |
|
23 | 19 | __nv_bfloat16 a = __float2bfloat16(3.0f);
|
24 | 20 | __nv_bfloat16 b = __float2bfloat16(2.0f);
|
25 | 21 | __nv_bfloat16 c = a - BFloat16(b);
|
26 | 22 | assert(static_cast<BFloat16>(c) == BFloat16(1.0));
|
27 |
| - |
28 |
| - // asserting if the functions used on |
29 |
| - // bfloat16 types give almost equivalent results when using |
30 |
| - // functions on double. |
31 |
| - // The purpose of these asserts are to test the device side |
32 |
| - // bfloat16 API for the common mathematical functions. |
33 |
| - // Note: When calling std math functions from device, don't |
34 |
| - // use the std namespace, but just "::" so that the function |
35 |
| - // gets resolved from nvcc math_functions.hpp |
36 |
| - |
37 |
| - float threshold = 0.00001; |
38 |
| - assert(::abs(::lgamma(BFloat16(10.0)) - ::lgamma(10.0f)) <= threshold); |
39 |
| - assert(::abs(::exp(BFloat16(1.0)) - ::exp(1.0f)) <= threshold); |
40 |
| - assert(::abs(::log(BFloat16(1.0)) - ::log(1.0f)) <= threshold); |
41 |
| - assert(::abs(::log10(BFloat16(1000.0)) - ::log10(1000.0f)) <= threshold); |
42 |
| - assert(::abs(::log1p(BFloat16(0.0)) - ::log1p(0.0f)) <= threshold); |
43 |
| - assert(::abs(::log2(BFloat16(1000.0)) - ::log2(1000.0f)) <= threshold); |
44 |
| - assert(::abs(::expm1(BFloat16(1.0)) - ::expm1(1.0f)) <= threshold); |
45 |
| - assert(::abs(::cos(BFloat16(0.0)) - ::cos(0.0f)) <= threshold); |
46 |
| - assert(::abs(::sin(BFloat16(0.0)) - ::sin(0.0f)) <= threshold); |
47 |
| - assert(::abs(::sqrt(BFloat16(100.0)) - ::sqrt(100.0f)) <= threshold); |
48 |
| - assert(::abs(::ceil(BFloat16(2.4)) - ::ceil(2.4f)) <= threshold); |
49 |
| - assert(::abs(::floor(BFloat16(2.7)) - ::floor(2.7f)) <= threshold); |
50 |
| - assert(::abs(::trunc(BFloat16(2.7)) - ::trunc(2.7f)) <= threshold); |
51 |
| - assert(::abs(::acos(BFloat16(-1.0)) - ::acos(-1.0f)) <= threshold); |
52 |
| - assert(::abs(::cosh(BFloat16(1.0)) - ::cosh(1.0f)) <= threshold); |
53 |
| - assert(::abs(::acosh(BFloat16(1.0)) - ::acosh(1.0f)) <= threshold); |
54 |
| - assert(::abs(::acosh(BFloat16(1.0)) - ::acosh(1.0f)) <= threshold); |
55 |
| - assert(::abs(::asinh(BFloat16(1.0)) - ::asinh(1.0f)) <= threshold); |
56 |
| - assert(::abs(::atanh(BFloat16(0.5)) - ::atanh(0.5f)) <= threshold); |
57 |
| - assert(::abs(::asin(BFloat16(1.0)) - ::asin(1.0f)) <= threshold); |
58 |
| - assert(::abs(::sinh(BFloat16(1.0)) - ::sinh(1.0f)) <= threshold); |
59 |
| - assert(::abs(::asinh(BFloat16(1.0)) - ::asinh(1.0f)) <= threshold); |
60 |
| - assert(::abs(::tan(BFloat16(0.0)) - ::tan(0.0f)) <= threshold); |
61 |
| - assert(::abs(::atan(BFloat16(1.0)) - ::atan(1.0f)) <= threshold); |
62 |
| - assert(::abs(::tanh(BFloat16(1.0)) - ::tanh(1.0f)) <= threshold); |
63 |
| - assert(::abs(::erf(BFloat16(10.0)) - ::erf(10.0f)) <= threshold); |
64 |
| - assert(::abs(::erfc(BFloat16(10.0)) - ::erfc(10.0f)) <= threshold); |
65 |
| - assert(::abs(::abs(BFloat16(-3.0)) - ::abs(-3.0f)) <= threshold); |
66 |
| - assert(::abs(::round(BFloat16(2.3)) - ::round(2.3f)) <= threshold); |
67 |
| - assert(::abs(::pow(BFloat16(2.0), BFloat16(10.0)) - ::pow(2.0f, 10.0f)) <= threshold); |
68 |
| - assert( |
69 |
| - ::abs(::atan2(BFloat16(7.0), BFloat16(0.0)) - ::atan2(7.0f, 0.0f)) <= threshold); |
70 |
| - // note: can't use namespace on isnan and isinf in device code |
71 |
| -#ifdef _MSC_VER |
72 |
| - // Windows requires this explicit conversion. The reason is unclear |
73 |
| - // related issue with clang: https://reviews.llvm.org/D37906 |
74 |
| - assert(::abs(::isnan((float)BFloat16(0.0)) - ::isnan(0.0f)) <= threshold); |
75 |
| - assert(::abs(::isinf((float)BFloat16(0.0)) - ::isinf(0.0f)) <= threshold); |
76 |
| -#else |
77 |
| - assert(::abs(::isnan(BFloat16(0.0)) - ::isnan(0.0f)) <= threshold); |
78 |
| - assert(::abs(::isinf(BFloat16(0.0)) - ::isinf(0.0f)) <= threshold); |
79 |
| -#endif |
80 |
| -} |
81 |
| - |
82 |
| -__global__ void kernel(){ |
83 |
| - test(); |
84 | 23 | }
|
85 | 24 |
|
86 | 25 | void launch_function(){
|
|
0 commit comments