Skip to content

Commit a91b68d

Browse files
author
Chao Liu
committed
DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element
1 parent 2cbabbb commit a91b68d

9 files changed

+176
-74
lines changed

composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
133133
static_assert(WPerThread % WoPerThreadSubC == 0, "");
134134

135135
// thread A buffer for GEMM
136-
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()>
136+
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
137137
a_thread_buf;
138138

139139
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,

composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
227227
// register allocation for output
228228
StaticBuffer<AddressSpaceEnum_t::Vgpr,
229229
FloatAcc,
230-
c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
230+
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
231+
true>
231232
c_thread_buf;
232233

233234
// initialize output thread tensor
@@ -251,7 +252,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
251252
// double regsiter buffer for b
252253
StaticBuffer<AddressSpaceEnum_t::Vgpr,
253254
FloatAB,
254-
b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
255+
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
256+
true>
255257
b_thread_even_buf, b_thread_odd_buf;
256258

257259
// LDS double buffer: preload data

composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
402402

403403
StaticBuffer<AddressSpaceEnum_t::Vgpr,
404404
vector_type<FloatAcc, BlkSize>,
405-
c_mr_nr_blk_desc.GetElementSpaceSize()>
405+
c_mr_nr_blk_desc.GetElementSpaceSize(),
406+
true>
406407
c_thread_buf;
407408

408409
// LDS allocation for A and B: be careful of alignment
@@ -493,7 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
493494
Number<M2>{},
494495
Number<1>{}));
495496

496-
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()>
497+
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
497498
c_blk_buf_;
498499

499500
static_for<0, MRepeat, 1>{}([&](auto mr_i) {

composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1242,7 +1242,7 @@ struct ThreadwiseTensorSliceTransfer_v3
12421242

12431243
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
12441244

1245-
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
1245+
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
12461246

12471247
SrcCoord src_coord_;
12481248
DstCoord dst_coord_;

composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
602602

603603
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
604604

605-
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
605+
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
606606

607607
SrcCoord src_coord_;
608608
DstCoord dst_coord_;

composable_kernel/include/utility/amd_buffer_addressing.hpp

Lines changed: 64 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,25 @@ union BufferResource
1010
{
1111
// 128 bit SGPRs to supply buffer resource in buffer instructions
1212
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
13-
int32x4_t data;
13+
int32x4_t content;
1414
StaticallyIndexedArray<T*, 2> address;
1515
StaticallyIndexedArray<int32_t, 4> range;
1616
StaticallyIndexedArray<int32_t, 4> config;
1717
};
1818

1919
template <typename T>
20-
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size)
20+
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
2121
{
2222
BufferResource<T> wave_buffer_resource;
2323

2424
// wavewise base address (64 bit)
2525
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
2626
// wavewise range (32 bit)
27-
wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T);
27+
wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
2828
// wavewise setting (32 bit)
2929
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
3030

31-
return wave_buffer_resource.data;
31+
return wave_buffer_resource.content;
3232
}
3333

3434
// load
@@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
204204
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
205205

206206
template <typename T, index_t N>
207-
__device__ typename vector_type<T, N>::type
208-
amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
209-
index_t src_thread_addr_offset,
210-
index_t src_wave_addr_offset)
207+
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
208+
index_t src_thread_addr_offset,
209+
index_t src_wave_addr_offset)
211210
{
212211
static_assert(
213212
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
@@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
412411
}
413412

414413
template <typename T, index_t N>
415-
__device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data,
416-
int32x4_t dst_wave_buffer_resource,
417-
index_t dst_thread_addr_offset,
418-
index_t dst_wave_addr_offset)
414+
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
415+
int32x4_t dst_wave_buffer_resource,
416+
index_t dst_thread_addr_offset,
417+
index_t dst_wave_addr_offset)
419418
{
420419
static_assert(
421420
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
584583

585584
// buffer_load requires:
586585
// 1) p_src_wave must be in global memory space
587-
// 2) p_src_wave to be a wavewise pointer.
586+
// 2) p_src_wave must be a wavewise pointer.
588587
// It is user's responsibility to make sure that is true.
589588
template <typename T, index_t N>
590589
__device__ typename vector_type_maker<T, N>::type::type
591-
amd_buffer_load_v2(const T* p_src_wave,
592-
index_t src_thread_data_offset,
593-
bool src_thread_data_valid,
594-
index_t src_element_space)
590+
amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
591+
index_t src_thread_element_offset,
592+
bool src_thread_element_valid,
593+
index_t src_element_space_size)
595594
{
596595
const int32x4_t src_wave_buffer_resource =
597-
make_wave_buffer_resource(p_src_wave, src_element_space);
596+
make_wave_buffer_resource(p_src_wave, src_element_space_size);
598597

599-
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
598+
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
599+
600+
using vector_t = typename vector_type_maker<T, N>::type::type;
601+
using scalar_t = typename scalar_type<vector_t>::type;
600602

601-
using vector_t = typename vector_type_maker<T, N>::type::type;
602-
using scalar_t = typename scalar_type<vector_t>::type;
603603
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
604604

605605
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
606-
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
606+
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
607607

608-
return amd_buffer_load_impl_v2<scalar_t, vector_size>(
608+
return amd_buffer_load_impl<scalar_t, vector_size>(
609609
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
610610
#else
611-
vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>(
611+
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
612612
src_wave_buffer_resource, src_thread_addr_offset, 0);
613613

614-
return src_thread_data_valid ? tmp : vector_t(0);
614+
return src_thread_element_valid ? tmp : vector_t(0);
615615
#endif
616616
}
617617

618+
// buffer_load requires:
619+
// 1) p_src_wave must be in global memory space
620+
// 2) p_src_wave must be a wavewise pointer.
621+
// It is user's responsibility to make sure that is true.
622+
template <typename T, index_t N>
623+
__device__ typename vector_type_maker<T, N>::type::type
624+
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
625+
index_t src_thread_element_offset,
626+
bool src_thread_element_valid,
627+
index_t src_element_space_size,
628+
T customized_value)
629+
{
630+
const int32x4_t src_wave_buffer_resource =
631+
make_wave_buffer_resource(p_src_wave, src_element_space_size);
632+
633+
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
634+
635+
using vector_t = typename vector_type_maker<T, N>::type::type;
636+
using scalar_t = typename scalar_type<vector_t>::type;
637+
638+
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
639+
640+
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
641+
src_wave_buffer_resource, src_thread_addr_offset, 0);
642+
643+
return src_thread_element_valid ? tmp : vector_t(customized_value);
644+
}
645+
618646
// buffer_store requires:
619647
// 1) p_dst_wave must be global memory
620648
// 2) p_dst_wave to be a wavewise pointer.
621649
// It is user's responsibility to make sure that is true.
622650
template <typename T, index_t N>
623-
__device__ void
624-
amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data,
625-
T* p_dst_wave,
626-
const index_t dst_thread_data_offset,
627-
const bool dst_thread_data_valid,
628-
const index_t dst_element_space)
651+
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
652+
T* p_dst_wave,
653+
const index_t dst_thread_element_offset,
654+
const bool dst_thread_element_valid,
655+
const index_t dst_element_space_size)
629656
{
630657
const int32x4_t dst_wave_buffer_resource =
631-
make_wave_buffer_resource(p_dst_wave, dst_element_space);
658+
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
632659

633-
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
660+
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
634661

635662
using vector_t = typename vector_type_maker<T, N>::type::type;
636663
using scalar_t = typename scalar_type<vector_t>::type;
637664
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
638665

639666
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
640-
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
667+
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
641668

642-
amd_buffer_store_impl_v2<scalar_t, vector_size>(
669+
amd_buffer_store_impl<scalar_t, vector_size>(
643670
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
644671
#else
645-
if(dst_thread_data_valid)
672+
if(dst_thread_element_valid)
646673
{
647-
amd_buffer_store_impl_v2<scalar_t, vector_size>(
674+
amd_buffer_store_impl<scalar_t, vector_size>(
648675
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
649676
}
650677
#endif

0 commit comments

Comments
 (0)