Skip to content

Commit 566b648

Browse files
authored
Code clean-up (pytorch#1285)
* code clean-up * remove the profiling output samples
1 parent fcba889 commit 566b648

38 files changed

+57
-259
lines changed

CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,18 +202,18 @@ endif()
202202

203203

204204
option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
205-
option(USE_OPT_NAVI3X "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)
205+
option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF)
206206

207207
if(USE_BITINT_EXTENSION_INT4)
208208
add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
209209
add_compile_options(-Wno-bit-int-extension)
210210
message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
211211
endif()
212212

213-
if(USE_OPT_NAVI3X)
213+
if(USE_OPT_GFX11)
214214
add_compile_options(-mcumode)
215215
add_compile_options(-mno-wavefrontsize64)
216-
message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}")
216+
message("CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
217217
endif()
218218

219219
## Threads

Jenkinsfile

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -515,38 +515,33 @@ def Build_CK(Map conf=[:]){
515515
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
516516
timeout(time: 24, unit: 'HOURS')
517517
{
518-
//check whether running on Navi or MI300 node
519-
def navi_node = 0
520-
def mi300_node = 0
518+
//check whether to run performance tests on this node
519+
def do_perf_tests = 0
521520
sh 'rocminfo | tee rocminfo.log'
522-
if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') ){
523-
navi_node = 1
524-
echo "This is a Navi node"
525-
}
526-
if ( runShell('grep -n "gfx942" rocminfo.log') ){
527-
mi300_node = 1
528-
echo "This is MI300 node"
521+
if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') || runShell('grep -n "gfx942" rocminfo.log') ){
522+
do_perf_tests = 1
523+
echo "Stash profiler and run performance tests"
529524
}
530525
cmake_build(conf)
531526
dir("build"){
532527
//run tests and examples
533528
sh 'make -j check'
534-
if (params.RUN_PERFORMANCE_TESTS && navi_node == 0 && mi300_node == 0 ){
529+
if (params.RUN_PERFORMANCE_TESTS && do_perf_tests == 0 ){
535530
//we only need the ckProfiler to run the performance tests, so we pack and stash it
536-
//do not stash profiler on Navi or MI300 nodes
531+
//do not stash profiler on nodes where we don't need to run performance tests
537532
sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler'
538533
stash name: "ckProfiler.tar.gz"
539534
}
540-
if (params.RUN_FULL_QA && mi300_node == 0 ){
541-
// build deb packages for all MI100/200/300 targets and prepare to export
535+
if (params.RUN_FULL_QA && do_perf_tests == 0 ){
536+
// build deb packages for all gfx9 targets and prepare to export
542537
sh 'make -j package'
543538
archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb'
544539
archiveArtifacts artifacts: 'composablekernel-tests_*.deb'
545540
sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb'
546541
stash name: "ckprofiler_0.2.0_amd64.deb"
547542
}
548543
}
549-
if (params.hipTensor_test && navi_node == 0 ){
544+
if (params.hipTensor_test && do_perf_tests == 0 ){
550545
//build and test hipTensor
551546
sh """#!/bin/bash
552547
rm -rf "${params.hipTensor_branch}".zip
@@ -814,7 +809,7 @@ pipeline {
814809
{
815810
parallel
816811
{
817-
stage("Run Codegen Tests on MI200")
812+
stage("Run Codegen Tests on gfx90a")
818813
{
819814
when {
820815
beforeAgent true
@@ -865,7 +860,7 @@ pipeline {
865860
cleanWs()
866861
}
867862
}
868-
stage("Build CK and run Tests on MI300")
863+
stage("Build CK and run Tests on gfx942")
869864
{
870865
when {
871866
beforeAgent true
@@ -885,7 +880,7 @@ pipeline {
885880
cleanWs()
886881
}
887882
}
888-
stage("Build CK and run Tests on MI200")
883+
stage("Build CK and run Tests on gfx90a")
889884
{
890885
when {
891886
beforeAgent true
@@ -925,13 +920,13 @@ pipeline {
925920
cleanWs()
926921
}
927922
}
928-
stage("Build CK and run Tests on Navi21")
923+
stage("Build CK and run Tests on gfx1030")
929924
{
930925
when {
931926
beforeAgent true
932927
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
933928
}
934-
agent{ label rocmnode("navi21") }
929+
agent{ label rocmnode("gfx1030") }
935930
environment{
936931
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """
937932
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
@@ -945,13 +940,13 @@ pipeline {
945940
cleanWs()
946941
}
947942
}
948-
stage("Build CK and run Tests on Navi32")
943+
stage("Build CK and run Tests on gfx1101")
949944
{
950945
when {
951946
beforeAgent true
952947
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
953948
}
954-
agent{ label rocmnode("navi32") }
949+
agent{ label rocmnode("gfx1101") }
955950
environment{
956951
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """
957952
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \

client_example/25_wrapper/wrapper_img2col.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,3 @@ int main(int argc, char* argv[])
181181
{1, 1, 1} /*filter_dilations*/);
182182
return 0;
183183
}
184-
// MI100 Perf: 0.255178 ms, 1698.9 GB/s,

example/01_gemm/README.md

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,3 @@
77
#arg3: run kernel # of times (>1)
88
./bin/example_gemm_xdl 0 1 5
99
```
10-
11-
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
12-
```
13-
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
14-
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
15-
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
16-
arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
17-
arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
18-
arg.c_grid_desc_m_n_{ 3840, 4096}
19-
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
20-
Warm up
21-
Start running 5 times...
22-
Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s
23-
```

example/02_gemm_bilinear/README.md

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,3 @@
99
#arg11 to 12: alpha, beta
1010
./bin/example_gemm_bilinear_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 0.5 0.5
1111
```
12-
Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16)
13-
```
14-
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
15-
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
16-
c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
17-
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
18-
arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
19-
arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
20-
arg.c0_grid_desc_m_n_{ 3840, 4096}
21-
arg.c_grid_desc_m_n_{ 3840, 4096}
22-
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
23-
Warm up
24-
Start running 1 times...
25-
Perf: 0.936965 ms, 137.517 TFlops, 102.959 GB/s
26-
error: 0
27-
max_diff: 0, 558.5, 558.5
28-
```

example/04_gemm_add_add_fastgelu/README.md

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,3 @@
88
#arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE"
99
./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1
1010
```
11-
12-
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
13-
```
14-
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
15-
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
16-
d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1}
17-
d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
18-
e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
19-
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
20-
Warm up 1 time
21-
Start running 10 times...
22-
Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8>
23-
```

example/09_convnd_fwd/README.md

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,3 @@
1616
# <right padding>, (ie RightPy, RightPx for 2D)
1717
./bin/example_convnd_fwd_xdl 0 1 100
1818
```
19-
20-
Result (MI100 @ 1087Mhz, 33.4TFlops peak FP32)
21-
```
22-
input: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
23-
weights: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192}
24-
output: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
25-
arg.a_grid_desc_k0_m_k1_{432, 165888, 4}
26-
arg.b_grid_desc_k0_n_k1_{432, 256, 4}
27-
arg.c_grid_desc_m_n_{ 165888, 256}
28-
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
29-
Warm up
30-
Start running 100 times...
31-
Perf: 4.43736 ms, 33.0753 TFlops, 150.357 GB/s
32-
```

example/15_grouped_gemm/README.md

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,3 @@
77
#arg3: run kernel # of times (>1)
88
./bin/example_grouped_gemm_xdl_fp16 0 1 5
99
```
10-
11-
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
12-
```
13-
gemm[0] a_m_k: dim 2, lengths {256, 64}, strides {64, 1} b_k_n: dim 2, lengths {64, 128}, strides {1, 64} c_m_n: dim 2, lengths {256, 128}, strides {128, 1}
14-
gemm[1] a_m_k: dim 2, lengths {512, 128}, strides {128, 1} b_k_n: dim 2, lengths {128, 256}, strides {1, 128} c_m_n: dim 2, lengths {512, 256}, strides {256, 1}
15-
gemm[2] a_m_k: dim 2, lengths {768, 192}, strides {192, 1} b_k_n: dim 2, lengths {192, 384}, strides {1, 192} c_m_n: dim 2, lengths {768, 384}, strides {384, 1}
16-
gemm[3] a_m_k: dim 2, lengths {1024, 256}, strides {256, 1} b_k_n: dim 2, lengths {256, 512}, strides {1, 256} c_m_n: dim 2, lengths {1024, 512}, strides {512, 1}
17-
group: 0 arg.a_grid_desc_k0_m_k1_{8, 256, 8}, arg.b_grid_desc_k0_n_k1_{8, 128, 8}, arg.c_grid_desc_m_n_{ 256, 128}
18-
group: 1 arg.a_grid_desc_k0_m_k1_{16, 512, 8}, arg.b_grid_desc_k0_n_k1_{16, 256, 8}, arg.c_grid_desc_m_n_{ 512, 256}
19-
group: 2 arg.a_grid_desc_k0_m_k1_{24, 768, 8}, arg.b_grid_desc_k0_n_k1_{24, 384, 8}, arg.c_grid_desc_m_n_{ 768, 384}
20-
group: 3 arg.a_grid_desc_k0_m_k1_{32, 1024, 8}, arg.b_grid_desc_k0_n_k1_{32, 512, 8}, arg.c_grid_desc_m_n_{ 1024, 512}
21-
launch_and_time_kernel: grid_dim {30, 1, 1}, block_dim {256, 1, 1}
22-
Warm up
23-
Start running 5 times...
24-
Perf: 0.037887 ms, 11.0706 TFlops, 90.8132 GB/s, DeviceGroupedGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2>
25-
```

example/26_contraction/README.md

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,3 @@
77
#arg3: time kernel (0=no, 1=yes)
88
./bin/example_contraction_bilinear_xdl_fp32 1 1 1
99
```
10-
11-
Result (MI100 @ dynamic freq, 46TFlops peak FP32)
12-
```
13-
a_ms_ks: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
14-
b_ks_ns: dim 4, lengths {32, 64, 32, 64}, strides {128, 1, 524288, 4096}
15-
c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
16-
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
17-
Warm up 1 time
18-
Start running 10 times...
19-
Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4>
20-
```

example/30_grouped_conv_fwd_multiple_d/README.md

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,3 @@ Following arguments (depending on number of spatial dims):
1616
./bin/example_grouped_conv_fwd_bias_relu_add_xdl_fp16 1 1 1
1717
```
1818

19-
Result (MI100)
20-
```
21-
in: dim 5, lengths {1, 128, 192, 71, 71}, strides {192, 967872, 1, 13632, 192}
22-
wei: dim 5, lengths {1, 256, 192, 3, 3}, strides {442368, 1728, 1, 576, 192}
23-
bias: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0}
24-
residual: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0}
25-
out: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 331776, 1, 9216, 256}
26-
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
27-
Warm up 1 time
28-
Start running 10 times...
29-
Perf: 1.55981 ms, 94.0927 TFlops, 213.868 GB/s, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 16, Default>
30-
```

example/46_gemm_add_multiply/README.md

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,3 @@
88
#arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE"
99
./bin/example_gemm_add_multiply_dl_fp16 1 1 1
1010
```
11-
12-
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
13-
```
14-
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
15-
b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
16-
d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1}
17-
d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
18-
e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
19-
arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2}
20-
arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2}
21-
arg.e_grid_desc_m_n_{ 3840, 4096}
22-
launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1}
23-
Warm up 1 time
24-
Start running 10 times...
25-
Perf: 3.99904 ms, 32.22 TFlops, 31.9913 GB/s, DeviceGemmMultipleD_Dl<256, 128, 128, 16, 2, 4, 4, 1>
26-
```

include/ck/ck.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
236236
#ifndef CK_WORKAROUND_DENORM_FIX
237237
#define CK_WORKAROUND_DENORM_FIX 0
238238
#else
239-
// enable only on MI200
239+
// enable only for gfx90a
240240
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
241241
#endif // CK_WORKAROUND_DENORM_FIX
242242

include/ck/host_utility/device_prop.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,20 +65,20 @@ inline bool is_lds_direct_load_supported()
6565
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
6666
}
6767

68-
inline bool is_navi1_supported()
68+
inline bool is_gfx101_supported()
6969
{
7070
return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
7171
ck::get_device_name() == "gfx1012";
7272
}
7373

74-
inline bool is_navi2_supported()
74+
inline bool is_gfx103_supported()
7575
{
7676
return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" ||
7777
ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" ||
7878
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
7979
}
8080

81-
inline bool is_navi3_supported()
81+
inline bool is_gfx11_supported()
8282
{
8383
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
8484
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";

include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
829829

830830
static bool IsSupportedArgument(const Argument& arg)
831831
{
832-
if(ck::is_navi3_supported())
832+
if(ck::is_gfx11_supported())
833833
{
834834
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
835835
{

include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
648648
static bool IsSupportedArgument(const Argument& arg)
649649
{
650650
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
651-
ck::is_navi2_supported() || ck::is_navi3_supported())
651+
ck::is_gfx103_supported() || ck::is_gfx11_supported())
652652
{
653653
bool pass = true;
654654
pass = pass && arg.K_ % K1 == 0;

include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
858858

859859
static bool IsSupportedArgument(const RawArg& arg)
860860
{
861-
if(ck::is_navi3_supported())
861+
if(ck::is_gfx11_supported())
862862
{
863863
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
864864
{
@@ -1435,7 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
14351435
#if 0
14361436
static bool IsSupportedArgument(const Argument& arg)
14371437
{
1438-
if(ck::is_navi3_supported())
1438+
if(ck::is_gfx11_supported())
14391439
{
14401440
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
14411441
{

include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,8 +1392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
13921392
static bool IsSupportedArgument(const Argument& arg)
13931393
{
13941394
// check device
1395-
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
1396-
ck::is_navi3_supported()))
1395+
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
1396+
ck::is_gfx11_supported()))
13971397
{
13981398
return false;
13991399
}

include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
509509

510510
static bool IsSupportedArgument(const Argument& arg)
511511
{
512-
if(ck::is_navi3_supported())
512+
if(ck::is_gfx11_supported())
513513
{
514514
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
515515
is_same_v<AccDataType, int32_t>))

include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,8 +535,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
535535
}
536536
}
537537

538-
if(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
539-
ck::is_navi3_supported())
538+
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
539+
ck::is_gfx11_supported())
540540
{
541541
return GridwiseGemm::CheckValidity(
542542
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);

include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout,
168168

169169
static bool IsSupportedArgument(const Argument& karg)
170170
{
171-
if(ck::is_navi2_supported() || ck::is_navi3_supported())
171+
if(ck::is_gfx103_supported() || ck::is_gfx11_supported())
172172
{
173173
return GridwiseGemm::CheckValidity(karg);
174174
}

include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
552552
static bool IsSupportedArgument(const Argument& arg)
553553
{
554554
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
555-
ck::is_navi2_supported() || ck::is_navi3_supported())
555+
ck::is_gfx103_supported() || ck::is_gfx11_supported())
556556
{
557557
return GridwiseGemm::CheckValidity(
558558
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);

0 commit comments

Comments
 (0)