From 9b6ed22d38de5d58fb482f7a2b5f5d3ec7b83192 Mon Sep 17 00:00:00 2001 From: Dounia Date: Tue, 29 Mar 2022 08:07:23 -0700 Subject: [PATCH 1/8] [SYCL][Matrix] Add support for tf32 type Signed-off-by: Dounia --- .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 44 ++++- sycl/test/matrix/matrix-tf32-test.cpp | 164 ++++++++++++++++++ 2 files changed, 203 insertions(+), 5 deletions(-) create mode 100644 sycl/test/matrix/matrix-tf32-test.cpp diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index fa467934c0cfd..cb6047bafbb46 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -70,6 +70,27 @@ struct joint_matrix { } }; +// class tf32 should not hold actual data. It is a tag type only, an empty class +// with no member variables. Morally, it is equivalent to an enumeration--it +// just uses the type system to communicate the desired accuracy of arithmetic +// computations. Users can't construct a tf32 +namespace precision { +class tf32 {}; +} // namespace precision + +// Differentiating between the "element type" and the "storage element type" +template struct helper_traits { + typedef T element_type; + typedef T storage_element_type; + typedef T fill_argument_type; +}; + +template <> struct helper_traits { + typedef precision::tf32 element_type; + typedef float storage_element_type; + typedef float fill_argument_type; +}; + template @@ -231,12 +252,16 @@ class wi_element { std::size_t idx; public: + typedef typename helper_traits::storage_element_type storage_element_type; wi_element(joint_matrix &Mat, std::size_t i) : M(Mat), idx(i) {} - operator T() { + operator storage_element_type() { #ifdef __SYCL_DEVICE_ONLY__ - return __spirv_VectorExtractDynamic(M.spvm, idx); + // TODO: __spirv_VectorExtractDynamic should also return + // storage_element_type + T elem = __spirv_VectorExtractDynamic(M.spvm, idx); + return reinterpret_cast(elem); #else throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); @@ -245,7 +270,11 @@ class wi_element { explicit operator bool() { #ifdef __SYCL_DEVICE_ONLY__ - return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast(0); + // TODO: __spirv_VectorExtractDynamic should also return + // storage_element_type + T elem = __spirv_VectorExtractDynamic(M.spvm, idx); + storage_element_type elems = reinterpret_cast(elem); + return elems != static_cast(0); #else throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); @@ -277,12 +306,17 @@ class wi_element { } #if __SYCL_DEVICE_ONLY__ + // TODO: __spirv_VectorInsertDynamic should take storage element type as + // argument #define OP(op) \ template wi_element &operator op##=(const T2 &rhs) { \ + T elem = __spirv_VectorExtractDynamic(M.spvm, idx); \ + storage_element_type elems = \ + reinterpret_cast(elem); \ M.spvm = __spirv_VectorInsertDynamic( \ M.spvm, \ - static_cast(__spirv_VectorExtractDynamic(M.spvm, idx) \ - op static_cast(rhs)), \ + static_cast( \ + elems op static_cast(rhs)), \ idx); \ return *this; \ } diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp new file mode 100644 index 0000000000000..0f6a3d40a6973 --- /dev/null +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -0,0 +1,164 @@ +// RUN: %clangxx -fsycl -O2 %s -o %t.out + +#include +#if (SYCL_EXT_ONEAPI_MATRIX == 2) +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 
8
+
+#define TM 8
+#define TN SG_SZ
+#define TK 16
+
+template struct big_matrix {
+public:
+  T *mat;
+
+public:
+  T *get_data() { return mat; }
+  void set_data(T *data) { mat = data; }
+  big_matrix(T *data) : mat(data) {}
+};
+
+template 
+void matrix_multiply(big_matrix &C,
+                     big_matrix &A,
+                     big_matrix &B) {
+  size_t M = NUM_ROWS_C;
+  size_t N = NUM_COLS_C;
+  size_t K = NUM_COLS_A;
+
+  assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B);
+  size_t NDRangeM = M / TM;
+  size_t NDRangeN = N / TN;
+  // buffer bufA(A.get_data(), range<2>(M, K));
+  buffer bufA(A.get_data(), range<2>(M, K));
+  buffer bufB(B.get_data(), range<2>(K, N));
+  buffer bufC((float *)C.get_data(), range<2>(M, N));
+
+  queue q;
+  q.submit([&](handler &cgh) {
+     auto accC = bufC.get_access(cgh);
+     auto accA = bufA.get_access(cgh);
+     auto accB = bufB.get_access(cgh);
+
+     cgh.parallel_for(
+         nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
+         [ accA, accB, accC, M, N, K ](nd_item<2> spmd_item)
+             [[intel::reqd_sub_group_size(SG_SZ)]]
+
+         {
+           // The submatrix API has to be accessed by all the workitems in a
+           // subgroup these functions will be called once by the subgroup no
+           // code divergence between the workitems
+           const auto global_idx = spmd_item.get_global_id(0);
+           const auto global_idy = spmd_item.get_global_id(1);
+           const auto sg_startx = global_idx - spmd_item.get_local_id(0);
+           const auto sg_starty = global_idy - spmd_item.get_local_id(1);
+
+           sub_group sg = spmd_item.get_sub_group();
+           joint_matrix sub_a(sg);
+           joint_matrix sub_b(
+               sg);
+           joint_matrix sub_c(sg);
+           joint_matrix_load(sg, sub_c,
+                             accC.get_pointer() + (sg_startx * TM) * N +
+                                 sg_starty / SG_SZ * TN,
+                             N, matrix_layout::row_major);
+           for (int k = 0; k < K; k += TK) {
+             joint_matrix_load(sg, sub_a,
+                               accA.get_pointer() + (sg_startx * TM) * K + k, K,
+                               matrix_layout::row_major);
+             // Assume we are already in VNNI format.
+ joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k) * (N) + + sg_starty / SG_SZ * TN, + N, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + float elem = wi_slice_a[i]; + // TODO: OP= is buggy: __spirv_VectorInsertDynamic should take + // storage element type as argument + // wi_slice_a[i] *= 2; + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +precision::tf32 A[MATRIX_M][MATRIX_K]; +precision::tf32 B[MATRIX_K][MATRIX_N]; +float C[MATRIX_M][MATRIX_N]; +float D[MATRIX_M][MATRIX_N]; + +precision::tf32 make_tf32(float x) { + uint32_t y = reinterpret_cast(x); + y += 0x1000u; + precision::tf32 res = reinterpret_cast(y); + return res; +} + +void matrix_multiply_ref(float *A_mem, float *B_mem, float *C_mem, int M, int N, + int K) { + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + float va = *(float *)(A_mem + m * K + k); + float vb = *(float *)(B_mem + k * N + n); + float acc = *((float *)(C_mem + m * N + n)); + *((float *)(C_mem + m * N + n)) = va * vb; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = make_tf32(1.0f * (i + j)); + } + } + for (int i = 0; i < MATRIX_K / 2; i++) { + for (int j = 0; j < MATRIX_N * 2; j++) { + B[i][j] = make_tf32(2.0f * i + 3.0f * j); + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1.0; + D[i][j] = 1.0; + } + } + + big_matrix MC((float *)&C); + big_matrix MD((float *)&D); + big_matrix MA((precision::tf32 *)&A); + big_matrix MB((precision::tf32 *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((float *)A, (float *)B, (float *)D, MATRIX_M, MATRIX_N, + MATRIX_K / 2); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; +} From 016523f30d2e57d51c51e9a41206f3726f5cb2ed Mon Sep 17 00:00:00 2001 From: Dounia Date: Tue, 29 Mar 2022 08:13:16 -0700 Subject: [PATCH 2/8] [SYCL][Matrix] Add a comment about conversion function --- sycl/test/matrix/matrix-tf32-test.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index 0f6a3d40a6973..0c8f1d6b2d777 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -104,6 +104,7 @@ precision::tf32 B[MATRIX_K][MATRIX_N]; float C[MATRIX_M][MATRIX_N]; float D[MATRIX_M][MATRIX_N]; +// this is a hack and should be replaced with a spirv function precision::tf32 make_tf32(float x) { uint32_t y = reinterpret_cast(x); y += 0x1000u; From 108f04abc9c2cbbed4f5aa1deb9fbed805d92990 Mon Sep 17 00:00:00 2001 From: Dounia Date: Tue, 29 Mar 2022 08:19:20 -0700 Subject: [PATCH 3/8] [SYCL][Matrix] minor formatting --- sycl/test/matrix/matrix-tf32-test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index 0c8f1d6b2d777..2027194befdb8 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ 
b/sycl/test/matrix/matrix-tf32-test.cpp @@ -49,7 +49,7 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [ accA, accB, accC, M, N, K ](nd_item<2> spmd_item) + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { From 98480abe82977fc60be2b125f8ecbb0593049280 Mon Sep 17 00:00:00 2001 From: Dounia Date: Tue, 7 Jun 2022 13:21:10 -0700 Subject: [PATCH 4/8] tf32 cannot be constructed, change load,store, and slicing signatures --- .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 60 +++++++++-------- sycl/test/matrix/matrix-tf32-test.cpp | 64 +++++++++++-------- 2 files changed, 67 insertions(+), 57 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index cb6047bafbb46..c8e04d977a2a4 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -70,63 +70,65 @@ struct joint_matrix { } }; -// class tf32 should not hold actual data. It is a tag type only, an empty class -// with no member variables. Morally, it is equivalent to an enumeration--it -// just uses the type system to communicate the desired accuracy of arithmetic -// computations. Users can't construct a tf32 +// class tf32 should not hold actual data. It is a tag type only, an empty +// class with no member variables. Morally, it is equivalent to an +// enumeration--it just uses the type system to communicate the desired +// accuracy of arithmetic computations. Users can't construct a tf32 namespace precision { class tf32 {}; } // namespace precision // Differentiating between the "element type" and the "storage element type" template struct helper_traits { - typedef T element_type; - typedef T storage_element_type; - typedef T fill_argument_type; + using element_type = T; + using storage_element_type = T; + using fill_argument_type = T; }; template <> struct helper_traits { - typedef precision::tf32 element_type; - typedef float storage_element_type; - typedef float fill_argument_type; + using element_type = precision::tf32; + using storage_element_type = float; + using fill_argument_type = float; }; -template inline __SYCL_ALWAYS_INLINE void joint_matrix_load(Group sg, - joint_matrix &res, - multi_ptr src, size_t stride, matrix_layout MemL) { + joint_matrix &res, + multi_ptr src, size_t stride, matrix_layout MemL) { #ifdef __SYCL_DEVICE_ONLY__ - T *Ptr = src.get(); + // For non tf32 case, check that Te is the same that Tm + Tm *Ptr = src.get(); + using Ts = typename helper_traits::storage_element_type; switch (MemL) { default: assert(false && "Invalid Memory Layout!"); case matrix_layout::row_major: res.spvm = - __spirv_JointMatrixLoadINTEL::value>( Ptr, stride, __spv::MatrixLayout::RowMajor, spv_scope_traits::value); break; case matrix_layout::col_major: res.spvm = - __spirv_JointMatrixLoadINTEL::value>( Ptr, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; case matrix_layout::packed_a: res.spvm = - __spirv_JointMatrixLoadINTEL::value>( Ptr, stride, __spv::MatrixLayout::PackedA, spv_scope_traits::value); break; case matrix_layout::packed_b: res.spvm = - __spirv_JointMatrixLoadINTEL::value>( Ptr, stride, __spv::MatrixLayout::PackedB, spv_scope_traits::value); @@ -252,7 +254,7 @@ class wi_element { std::size_t idx; public: - typedef typename helper_traits::storage_element_type storage_element_type; + using storage_element_type = typename helper_traits::storage_element_type; 
wi_element(joint_matrix &Mat, std::size_t i) : M(Mat), idx(i) {} @@ -306,8 +308,6 @@ class wi_element { } #if __SYCL_DEVICE_ONLY__ - // TODO: __spirv_VectorInsertDynamic should take storage element type as - // argument #define OP(op) \ template wi_element &operator op##=(const T2 &rhs) { \ T elem = __spirv_VectorExtractDynamic(M.spvm, idx); \ @@ -315,9 +315,7 @@ class wi_element { reinterpret_cast(elem); \ M.spvm = __spirv_VectorInsertDynamic( \ M.spvm, \ - static_cast( \ - elems op static_cast(rhs)), \ - idx); \ + static_cast(elems op static_cast(rhs)), idx); \ return *this; \ } #else // __SYCL_DEVICE_ONLY__ @@ -337,10 +335,10 @@ class wi_element { // Note that similarly to the other matrix functions, uint16_t is used here to // represent bf16 type. Since the AMX and DPAS implementations don't support -// uint16_t, this interpretation is possible. This design choice was made before -// the introduction of SYCL experimental bfloat16 type. Our plan is to move -// towards using the SYCL bfloat16. But since it is still experimental, we will -// probably keep both uint16 interpretation and SYCL bfloat16. +// uint16_t, this interpretation is possible. This design choice was made +// before the introduction of SYCL experimental bfloat16 type. Our plan is to +// move towards using the SYCL bfloat16. But since it is still experimental, +// we will probably keep both uint16 interpretation and SYCL bfloat16. template class wi_element { joint_matrix &M; @@ -395,8 +393,8 @@ class wi_element { // We use here the following functions for conversion (bf16=>fp32 and // fp32=>bf16). This is a workaround until we are able to use - // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL once these are - // supported in the CPU backend + // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL once these + // are supported in the CPU backend static float make_fp32(uint16_t x) { unsigned int y = x; y = y << 16; diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index 2027194befdb8..78775af1a0218 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -1,13 +1,13 @@ // RUN: %clangxx -fsycl -O2 %s -o %t.out -#include +#include #if (SYCL_EXT_ONEAPI_MATRIX == 2) #include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; -#define SG_SZ 8 +auto constexpr SG_SZ = 8; #define TM 8 #define TN SG_SZ @@ -23,6 +23,15 @@ template struct big_matrix { big_matrix(T *data) : mat(data) {} }; +// this should be replaced with a DPC++ and spirv functions +float round_to_tf32(float a) { + uint32_t tmp_uint = reinterpret_cast(a); + tmp_uint += 0x1000u; // Round up the 13th last bit + tmp_uint &= 0xFFFFE000u; // Zero out the bottom 13 bits + float ret = reinterpret_cast(tmp_uint); + return ret; +} + template @@ -37,8 +46,8 @@ void matrix_multiply(big_matrix &C, size_t NDRangeM = M / TM; size_t NDRangeN = N / TN; // buffer bufA(A.get_data(), range<2>(M, K)); - buffer bufA(A.get_data(), range<2>(M, K)); - buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); buffer bufC((float *)C.get_data(), range<2>(M, N)); queue q; @@ -48,9 +57,8 @@ void matrix_multiply(big_matrix &C, auto accB = bufB.get_access(cgh); cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(SG_SZ)]] + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), [= + 
](nd_item<2> spmd_item)[[intel::reqd_sub_group_size(SG_SZ)]] { // The submatrix API has to be accessed by all the workitems in a @@ -79,6 +87,16 @@ void matrix_multiply(big_matrix &C, accB.get_pointer() + (k) * (N) + sg_starty / SG_SZ * TN, N, matrix_layout::packed_b); + // If no rounding to tf32 function is called, the mad function will + // work on truncated floats. + // TODO: change signature of __spirv_VectorInsertDynamic to have + // two types: matrix type can be different from value type + for (int i = 0; i < sub_a.get_wi_data().length(); i++) { + sub_a.get_wi_data()[i] = round_to_tf32(sub_a.get_wi_data()[i]); + } + for (int i = 0; i < sub_b.get_wi_data().length(); i++) { + sub_b.get_wi_data()[i] = round_to_tf32(sub_b.get_wi_data()[i]); + } sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_a = sub_a.get_wi_data(); @@ -93,34 +111,27 @@ void matrix_multiply(big_matrix &C, sg_starty / SG_SZ * TN, N, matrix_layout::row_major); }); // parallel for - }).wait(); + }) + .wait(); } static constexpr size_t MATRIX_M = TM * 2; static constexpr size_t MATRIX_N = TN * 2; static constexpr size_t MATRIX_K = TK * 2; -precision::tf32 A[MATRIX_M][MATRIX_K]; -precision::tf32 B[MATRIX_K][MATRIX_N]; +float A[MATRIX_M][MATRIX_K]; +float B[MATRIX_K][MATRIX_N]; float C[MATRIX_M][MATRIX_N]; float D[MATRIX_M][MATRIX_N]; -// this is a hack and should be replaced with a spirv function -precision::tf32 make_tf32(float x) { - uint32_t y = reinterpret_cast(x); - y += 0x1000u; - precision::tf32 res = reinterpret_cast(y); - return res; -} - void matrix_multiply_ref(float *A_mem, float *B_mem, float *C_mem, int M, int N, int K) { for (int m = 0; m < M; m++) for (int n = 0; n < N; n++) { for (int k = 0; k < K; k++) { - float va = *(float *)(A_mem + m * K + k); - float vb = *(float *)(B_mem + k * N + n); - float acc = *((float *)(C_mem + m * N + n)); - *((float *)(C_mem + m * N + n)) = va * vb; + float va = A_mem[m * K + k]; + float vb = B_mem[k * N + n]; + float acc = C_mem[m * N + n]; + C_mem[m * N + n] = va * vb; } } } @@ -128,12 +139,12 @@ void matrix_multiply_ref(float *A_mem, float *B_mem, float *C_mem, int M, int N, int main() { for (int i = 0; i < MATRIX_M; i++) { for (int j = 0; j < MATRIX_K; j++) { - A[i][j] = make_tf32(1.0f * (i + j)); + A[i][j] = 1.0f * (i + j); } } for (int i = 0; i < MATRIX_K / 2; i++) { for (int j = 0; j < MATRIX_N * 2; j++) { - B[i][j] = make_tf32(2.0f * i + 3.0f * j); + B[i][j] = 2.0f * i + 3.0f * j; } } for (int i = 0; i < MATRIX_M; i++) { @@ -145,8 +156,8 @@ int main() { big_matrix MC((float *)&C); big_matrix MD((float *)&D); - big_matrix MA((precision::tf32 *)&A); - big_matrix MB((precision::tf32 *)&B); + big_matrix MA((float *)&A); + big_matrix MB((float *)&B); matrix_multiply(MC, MA, MB); matrix_multiply_ref((float *)A, (float *)B, (float *)D, MATRIX_M, MATRIX_N, MATRIX_K / 2); @@ -163,3 +174,4 @@ int main() { else std::cout << "failed\n"; } +#endif // (SYCL_EXT_ONEAPI_MATRIX == 2) From 4d7c661754b899d566d57fe994389fb1ca5aa25c Mon Sep 17 00:00:00 2001 From: Dounia Khaldi Date: Fri, 19 Aug 2022 14:19:54 -0700 Subject: [PATCH 5/8] Change the signatures of extract and insert dynamic to return storage type fp32 instead of tf32 --- sycl/include/CL/__spirv/spirv_ops.hpp | 14 +++--- .../sycl/ext/oneapi/matrix/matrix-jit.hpp | 46 ++++++++++--------- sycl/test/matrix/matrix-tf32-test.cpp | 11 ++--- 3 files changed, 36 insertions(+), 35 deletions(-) diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 9bfbb759271fa..b24f58e8ccd66 
100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -22,11 +22,11 @@ #endif #ifdef __SYCL_DEVICE_ONLY__ -template extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * -__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride, +__spirv_JointMatrixLoadINTEL(Ts *Ptr, std::size_t Stride, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S, int MemOperand = 0); @@ -97,16 +97,18 @@ template *); -template -extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic( +extern SYCL_EXTERNAL Ts __spirv_VectorExtractDynamic( __spv::__spirv_JointMatrixINTEL *, size_t i); -template extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL *, - T val, size_t i); + Ts val, size_t i); #ifndef __SPIRV_BUILTIN_DECLARATIONS__ #error \ diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index c8e04d977a2a4..d1a545737ebbd 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -70,10 +70,10 @@ struct joint_matrix { } }; -// class tf32 should not hold actual data. It is a tag type only, an empty -// class with no member variables. Morally, it is equivalent to an -// enumeration--it just uses the type system to communicate the desired -// accuracy of arithmetic computations. Users can't construct a tf32 +// class tf32 should not hold actual data. It is a tag type only, an empty class +// with no member variables. Morally, it is equivalent to an enumeration--it +// just uses the type system to communicate the desired accuracy of arithmetic +// computations. Users can't construct a tf32 namespace precision { class tf32 {}; } // namespace precision @@ -260,10 +260,10 @@ class wi_element { : M(Mat), idx(i) {} operator storage_element_type() { #ifdef __SYCL_DEVICE_ONLY__ - // TODO: __spirv_VectorExtractDynamic should also return - // storage_element_type - T elem = __spirv_VectorExtractDynamic(M.spvm, idx); - return reinterpret_cast(elem); + // __spirv_VectorExtractDynamic returns storage_element_type + storage_element_type elem = + __spirv_VectorExtractDynamic(M.spvm, idx); + return elem; #else throw runtime_error("joint matrix is not supported on host device.", PI_INVALID_DEVICE); @@ -272,10 +272,9 @@ class wi_element { explicit operator bool() { #ifdef __SYCL_DEVICE_ONLY__ - // TODO: __spirv_VectorExtractDynamic should also return - // storage_element_type - T elem = __spirv_VectorExtractDynamic(M.spvm, idx); - storage_element_type elems = reinterpret_cast(elem); + // __spirv_VectorExtractDynamic returns storage_element_type + storage_element_type elems = + __spirv_VectorExtractDynamic(M.spvm, idx); return elems != static_cast(0); #else throw runtime_error("joint matrix is not supported on host device.", @@ -285,7 +284,9 @@ class wi_element { template wi_element &operator=(const T2 &rhs) { #ifdef __SYCL_DEVICE_ONLY__ - M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast(rhs), idx); + // __spirv_VectorInsertDynamic takes storage_element_type as argument + M.spvm = __spirv_VectorInsertDynamic( + M.spvm, static_cast(rhs), idx); return *this; #else (void)rhs; @@ -310,12 +311,13 @@ class wi_element { #if __SYCL_DEVICE_ONLY__ #define OP(op) \ template wi_element &operator op##=(const T2 &rhs) { \ - T elem = __spirv_VectorExtractDynamic(M.spvm, idx); \ storage_element_type elems = \ - reinterpret_cast(elem); \ + __spirv_VectorExtractDynamic(M.spvm, idx); \ M.spvm = 
__spirv_VectorInsertDynamic( \ M.spvm, \ - static_cast(elems op static_cast(rhs)), idx); \ + static_cast( \ + elems op static_cast(rhs)), \ + idx); \ return *this; \ } #else // __SYCL_DEVICE_ONLY__ @@ -335,10 +337,10 @@ class wi_element { // Note that similarly to the other matrix functions, uint16_t is used here to // represent bf16 type. Since the AMX and DPAS implementations don't support -// uint16_t, this interpretation is possible. This design choice was made -// before the introduction of SYCL experimental bfloat16 type. Our plan is to -// move towards using the SYCL bfloat16. But since it is still experimental, -// we will probably keep both uint16 interpretation and SYCL bfloat16. +// uint16_t, this interpretation is possible. This design choice was made before +// the introduction of SYCL experimental bfloat16 type. Our plan is to move +// towards using the SYCL bfloat16. But since it is still experimental, we will +// probably keep both uint16 interpretation and SYCL bfloat16. template class wi_element { joint_matrix &M; @@ -393,8 +395,8 @@ class wi_element { // We use here the following functions for conversion (bf16=>fp32 and // fp32=>bf16). This is a workaround until we are able to use - // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL once these - // are supported in the CPU backend + // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL once these are + // supported in the CPU backend static float make_fp32(uint16_t x) { unsigned int y = x; y = y << 16; diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index 78775af1a0218..83d7735fee9da 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -57,8 +57,8 @@ void matrix_multiply(big_matrix &C, auto accB = bufB.get_access(cgh); cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), [= - ](nd_item<2> spmd_item)[[intel::reqd_sub_group_size(SG_SZ)]] + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { // The submatrix API has to be accessed by all the workitems in a @@ -102,17 +102,14 @@ void matrix_multiply(big_matrix &C, auto wi_slice_a = sub_a.get_wi_data(); for (int i = 0; i < wi_slice_a.length(); i++) { float elem = wi_slice_a[i]; - // TODO: OP= is buggy: __spirv_VectorInsertDynamic should take - // storage element type as argument - // wi_slice_a[i] *= 2; + wi_slice_a[i] *= 2; } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); }); // parallel for - }) - .wait(); + }).wait(); } static constexpr size_t MATRIX_M = TM * 2; From e349b711e7158d8cfd2890a8185b04a33137639c Mon Sep 17 00:00:00 2001 From: Dounia Khaldi Date: Mon, 22 Aug 2022 11:15:43 -0700 Subject: [PATCH 6/8] make it illegal to construct tf32 class type --- sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index d1a545737ebbd..3a9b3a578e58c 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -75,7 +75,9 @@ struct joint_matrix { // just uses the type system to communicate the desired accuracy of arithmetic // computations. 
Users can't construct a tf32 namespace precision { -class tf32 {}; +class tf32 { + tf32() = delete; +}; } // namespace precision // Differentiating between the "element type" and the "storage element type" From a910a252b2ce1a23900b1ee9a9b9e88e7c56a44b Mon Sep 17 00:00:00 2001 From: Dounia Date: Mon, 22 Aug 2022 11:51:35 -0700 Subject: [PATCH 7/8] formatting --- sycl/test/matrix/matrix-tf32-test.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index 83d7735fee9da..3de4b0c005534 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -57,8 +57,8 @@ void matrix_multiply(big_matrix &C, auto accB = bufB.get_access(cgh); cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), [= + ](nd_item<2> spmd_item)[[intel::reqd_sub_group_size(SG_SZ)]] { // The submatrix API has to be accessed by all the workitems in a @@ -109,7 +109,8 @@ void matrix_multiply(big_matrix &C, sg_starty / SG_SZ * TN, N, matrix_layout::row_major); }); // parallel for - }).wait(); + }) + .wait(); } static constexpr size_t MATRIX_M = TM * 2; From 682dff9465e4a93eb5adbd8aba32a8e3629381b5 Mon Sep 17 00:00:00 2001 From: Dounia Date: Tue, 23 Aug 2022 12:20:42 -0700 Subject: [PATCH 8/8] update branch --- sycl/test/matrix/matrix-tf32-test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index 3de4b0c005534..abc616edb2545 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -61,7 +61,7 @@ void matrix_multiply(big_matrix &C, ](nd_item<2> spmd_item)[[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a + // The matrix API has to be accessed by all the workitems in a // subgroup these functions will be called once by the subgroup no // code divergence between the workitems const auto global_idx = spmd_item.get_global_id(0);
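A note on the helper_traits machinery the first and fourth patches introduce: precision::tf32 is only a tag type, so the trait maps it to float for storage and fill arguments, while every ordinary element type maps to itself. The sketch below restates the trait assuming the primary template is parameterized on a generic element type T and the specialization is on precision::tf32, as the member typedefs indicate; it is a standalone illustration of the element-type versus storage-element-type split, not the exact committed code.

#include <type_traits>

namespace precision {
class tf32 {
  tf32() = delete; // tag type only; users cannot construct a tf32 value
};
} // namespace precision

// Primary template: for ordinary element types the storage type is the
// element type itself.
template <typename T> struct helper_traits {
  using element_type = T;
  using storage_element_type = T;
  using fill_argument_type = T;
};

// Specialization: tf32 is a tag, so the value a work-item actually reads,
// writes and fills with is a float.
template <> struct helper_traits<precision::tf32> {
  using element_type = precision::tf32;
  using storage_element_type = float;
  using fill_argument_type = float;
};

static_assert(std::is_same_v<
              helper_traits<precision::tf32>::storage_element_type, float>);
static_assert(std::is_same_v<helper_traits<float>::storage_element_type, float>);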
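The round_to_tf32 helper that the fourth patch adds to the test captures the tf32 numeric model: keep the float's 8 exponent bits and the top 10 mantissa bits, rounding toward the nearest representable value by adding one at bit 12 and then clearing the 13 low bits. Below is a minimal host-only sketch of that scheme; it uses std::memcpy for the bit reinterpretation, the names are illustrative only, and, as the patch's own comment notes, the long-term plan is a dedicated DPC++/SPIR-V conversion rather than this bit manipulation.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round a float to tf32 precision: 8 exponent bits, 10 mantissa bits.
float round_to_tf32_host(float a) {
  uint32_t bits;
  std::memcpy(&bits, &a, sizeof(bits));
  bits += 0x1000u;     // round: carry into bit 13 when bit 12 is set
  bits &= 0xFFFFE000u; // zero the 13 mantissa bits tf32 does not keep
  float rounded;
  std::memcpy(&rounded, &bits, sizeof(rounded));
  return rounded;
}

int main() {
  float x = 1.000244141f; // 1 + 2^-12, finer than tf32 resolves near 1.0
  // Prints "1.000244141 -> 1.000000000": the extra mantissa bits round away.
  std::printf("%.9f -> %.9f\n", x, round_to_tf32_host(x));
  return 0;
}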