@@ -10,25 +10,25 @@ union BufferResource
10
10
{
11
11
// 128 bit SGPRs to supply buffer resource in buffer instructions
12
12
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
13
- int32x4_t data ;
13
+ int32x4_t content ;
14
14
StaticallyIndexedArray<T*, 2 > address;
15
15
StaticallyIndexedArray<int32_t , 4 > range;
16
16
StaticallyIndexedArray<int32_t , 4 > config;
17
17
};
18
18
19
19
template <typename T>
20
- __device__ int32x4_t make_wave_buffer_resource (T* p_wave, index_t data_space_size )
20
+ __device__ int32x4_t make_wave_buffer_resource (T* p_wave, index_t element_space_size )
21
21
{
22
22
BufferResource<T> wave_buffer_resource;
23
23
24
24
// wavewise base address (64 bit)
25
25
wave_buffer_resource.address (Number<0 >{}) = const_cast <remove_cv_t <T>*>(p_wave);
26
26
// wavewise range (32 bit)
27
- wave_buffer_resource.range (Number<2 >{}) = data_space_size * sizeof (T);
27
+ wave_buffer_resource.range (Number<2 >{}) = element_space_size * sizeof (T);
28
28
// wavewise setting (32 bit)
29
29
wave_buffer_resource.config (Number<3 >{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
30
30
31
- return wave_buffer_resource.data ;
31
+ return wave_buffer_resource.content ;
32
32
}
33
33
34
34
// load
@@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
204
204
index_t glc_slc) __asm(" llvm.amdgcn.raw.buffer.store.v4f32" );
205
205
206
206
template <typename T, index_t N>
207
- __device__ typename vector_type<T, N>::type
208
- amd_buffer_load_impl_v2 (int32x4_t src_wave_buffer_resource,
209
- index_t src_thread_addr_offset,
210
- index_t src_wave_addr_offset)
207
+ __device__ typename vector_type<T, N>::type amd_buffer_load_impl (int32x4_t src_wave_buffer_resource,
208
+ index_t src_thread_addr_offset,
209
+ index_t src_wave_addr_offset)
211
210
{
212
211
static_assert (
213
212
(is_same<T, float >::value && (N == 1 || N == 2 || N == 4 || N == 8 )) ||
@@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
412
411
}
413
412
414
413
template <typename T, index_t N>
415
- __device__ void amd_buffer_store_impl_v2 (const typename vector_type<T, N>::type src_thread_data,
416
- int32x4_t dst_wave_buffer_resource,
417
- index_t dst_thread_addr_offset,
418
- index_t dst_wave_addr_offset)
414
+ __device__ void amd_buffer_store_impl (const typename vector_type<T, N>::type src_thread_data,
415
+ int32x4_t dst_wave_buffer_resource,
416
+ index_t dst_thread_addr_offset,
417
+ index_t dst_wave_addr_offset)
419
418
{
420
419
static_assert (
421
420
(is_same<T, float >::value && (N == 1 || N == 2 || N == 4 )) ||
@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
584
583
585
584
// buffer_load requires:
586
585
// 1) p_src_wave must be in global memory space
587
- // 2) p_src_wave to be a wavewise pointer.
586
+ // 2) p_src_wave must be a wavewise pointer.
588
587
// It is user's responsibility to make sure that is true.
589
588
template <typename T, index_t N>
590
589
__device__ typename vector_type_maker<T, N>::type::type
591
- amd_buffer_load_v2 (const T* p_src_wave,
592
- index_t src_thread_data_offset ,
593
- bool src_thread_data_valid ,
594
- index_t src_element_space )
590
+ amd_buffer_load_invalid_element_return_return_zero (const T* p_src_wave,
591
+ index_t src_thread_element_offset ,
592
+ bool src_thread_element_valid ,
593
+ index_t src_element_space_size )
595
594
{
596
595
const int32x4_t src_wave_buffer_resource =
597
- make_wave_buffer_resource (p_src_wave, src_element_space );
596
+ make_wave_buffer_resource (p_src_wave, src_element_space_size );
598
597
599
- index_t src_thread_addr_offset = src_thread_data_offset * sizeof (T);
598
+ index_t src_thread_addr_offset = src_thread_element_offset * sizeof (T);
599
+
600
+ using vector_t = typename vector_type_maker<T, N>::type::type;
601
+ using scalar_t = typename scalar_type<vector_t >::type;
600
602
601
- using vector_t = typename vector_type_maker<T, N>::type::type;
602
- using scalar_t = typename scalar_type<vector_t >::type;
603
603
constexpr index_t vector_size = scalar_type<vector_t >::vector_size;
604
604
605
605
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
606
- uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff ;
606
+ uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff ;
607
607
608
- return amd_buffer_load_impl_v2 <scalar_t , vector_size>(
608
+ return amd_buffer_load_impl <scalar_t , vector_size>(
609
609
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0 );
610
610
#else
611
- vector_t tmp = amd_buffer_load_impl_v2 <scalar_t , vector_size>(
611
+ vector_t tmp = amd_buffer_load_impl <scalar_t , vector_size>(
612
612
src_wave_buffer_resource, src_thread_addr_offset, 0 );
613
613
614
- return src_thread_data_valid ? tmp : vector_t (0 );
614
+ return src_thread_element_valid ? tmp : vector_t (0 );
615
615
#endif
616
616
}
617
617
618
+ // buffer_load requires:
619
+ // 1) p_src_wave must be in global memory space
620
+ // 2) p_src_wave must be a wavewise pointer.
621
+ // It is user's responsibility to make sure that is true.
622
+ template <typename T, index_t N>
623
+ __device__ typename vector_type_maker<T, N>::type::type
624
+ amd_buffer_load_invalid_element_return_customized_value (const T* p_src_wave,
625
+ index_t src_thread_element_offset,
626
+ bool src_thread_element_valid,
627
+ index_t src_element_space_size,
628
+ T customized_value)
629
+ {
630
+ const int32x4_t src_wave_buffer_resource =
631
+ make_wave_buffer_resource (p_src_wave, src_element_space_size);
632
+
633
+ index_t src_thread_addr_offset = src_thread_element_offset * sizeof (T);
634
+
635
+ using vector_t = typename vector_type_maker<T, N>::type::type;
636
+ using scalar_t = typename scalar_type<vector_t >::type;
637
+
638
+ constexpr index_t vector_size = scalar_type<vector_t >::vector_size;
639
+
640
+ vector_t tmp = amd_buffer_load_impl<scalar_t , vector_size>(
641
+ src_wave_buffer_resource, src_thread_addr_offset, 0 );
642
+
643
+ return src_thread_element_valid ? tmp : vector_t (customized_value);
644
+ }
645
+
618
646
// buffer_store requires:
619
647
// 1) p_dst_wave must be global memory
620
648
// 2) p_dst_wave to be a wavewise pointer.
621
649
// It is user's responsibility to make sure that is true.
622
650
template <typename T, index_t N>
623
- __device__ void
624
- amd_buffer_store_v2 (const typename vector_type_maker<T, N>::type::type src_thread_data,
625
- T* p_dst_wave,
626
- const index_t dst_thread_data_offset,
627
- const bool dst_thread_data_valid,
628
- const index_t dst_element_space)
651
+ __device__ void amd_buffer_store (const typename vector_type_maker<T, N>::type::type src_thread_data,
652
+ T* p_dst_wave,
653
+ const index_t dst_thread_element_offset,
654
+ const bool dst_thread_element_valid,
655
+ const index_t dst_element_space_size)
629
656
{
630
657
const int32x4_t dst_wave_buffer_resource =
631
- make_wave_buffer_resource (p_dst_wave, dst_element_space );
658
+ make_wave_buffer_resource (p_dst_wave, dst_element_space_size );
632
659
633
- index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof (T);
660
+ index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof (T);
634
661
635
662
using vector_t = typename vector_type_maker<T, N>::type::type;
636
663
using scalar_t = typename scalar_type<vector_t >::type;
637
664
constexpr index_t vector_size = scalar_type<vector_t >::vector_size;
638
665
639
666
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
640
- uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff ;
667
+ uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff ;
641
668
642
- amd_buffer_store_impl_v2 <scalar_t , vector_size>(
669
+ amd_buffer_store_impl <scalar_t , vector_size>(
643
670
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0 );
644
671
#else
645
- if (dst_thread_data_valid )
672
+ if (dst_thread_element_valid )
646
673
{
647
- amd_buffer_store_impl_v2 <scalar_t , vector_size>(
674
+ amd_buffer_store_impl <scalar_t , vector_size>(
648
675
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0 );
649
676
}
650
677
#endif
0 commit comments