Commit 08d5c02

Authored by: illsilin, aska-0096, dependabot[bot], Jing Zhang (zjing14)
OCP FP8 support for gfx12. (pytorch#1710)
* (2/5) bilinear gemm pass; perf bug: skipping the A LDS has lower performance than skipping the B LDS
* (3/5) batched gemm pass; perf bug: skipping the A LDS has lower performance than skipping the B LDS
* (4/5) grouped conv pass
* (5/5) attention pass; TODO: debug the LDS perf bug
* AIT Attention API refactor (pytorch#8)
* sanity pass
* sanity pass 2
* confirm significant performance regression
* turn on all instances
* turn off instance format
* Fix bug, tuning & format
* DML meta, self_attn + cross_attn
* sanity pass
* remove useless flag
* update tile and problem size used in AIT attention
* bug fix in grouped conv support check
* deprecate inline asm wmma
* Bug fix: double LDS skip
* clang-format
* Fix errors in 1. example (fmha), 2. gridwise pipeline, 3. deviceop (fmha); change some containers from vector to array
* part 2 of previous commit
* clang-format
* API fix of GridwiseGemmPipeline
* separate array-based and vector-based attention tensor transformations
* fix gemm
* clang-format
* add gemm fp16 instances
* Temp save
* fpAintB kernel compile pass
* Sanity pass
* Temp save
* debug code enabled
* Fp16AInt8B_GEMM sanity
* MQA implementation
* GQA-4 example
* temp save
* Compile pass
* New implementation of fp16A/int8B GEMM; achieves similar math throughput to native fp16 GEMM
* Bump rocm-docs-core from 0.24.0 to 0.29.0 in /docs/sphinx. Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.24.0 to 0.29.0.
  - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases)
  - [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md)
  - [Commits](ROCm/rocm-docs-core@v0.24.0...v0.29.0)
  updated-dependencies: rocm-docs-core (dependency-type: direct:production, update-type: version-update:semver-minor)
  Signed-off-by: dependabot[bot] <[email protected]>
* initial enablement of gfx950
* fix clang format
* disable examples 31 and 41 int8 on gfx950
* initial navi4x enablement
* remove extra endif
* enabled dl_gemm
* update s_barrier and s_waitcnt for gfx12
* fix the gfx12 assembly syntax
* fixed block_sync_lds
* add support for more dl kernels on navi4
* add wmma
* format
* TODO: fix gemm_bilinear_wmma instances compilation bug
* Solve a bug when K1=16
* remove unnecessary changes
* Remove tensor layout limitation to LDS usage in tensor contraction
* fixed block_sync_lds
* merge navi3_ref
* update self-attention and cross-attention
* fix a typo in a name
* fixed layout
* debugging
* Add arch limiter for fp8 gemm
* fixed wmma
* enable fp8 gemm_xdl for all gfx9 targets
* temporarily disable gemm_xdl_fp16_fp8 on MI100/200
* fix the cmake logic for gemm_xdl_fp16_fp8
* fixed c_output
* re-enable gemm_xdl_fp16_fp8 on MI100/200
* fixed gfx12
* fixed
* fixed
* separate gfx12 blockwise_gemm
* fixed
* enable fwd conv on navi4x
* enable gridwise
* enabled gemm
* fixed merge
* remove empty example folder
* fixed conflicts
* some small changes
* Update cmake-ck-dev.sh
* Update cmake-ck-dev.sh
* enabled other types
* fixed register loads
* test fa
* enable gfx12
* clean up
* enable some instances on gfx12
* add gfx1201 macro in amd_wmma header
* fix clang format
* enable batched_gemm_softmax_gemm_perm_wmma for gfx12
* disable instances with blocksize=256 in attention examples
* debugging
* debug
* fixed lds_enabled
* debugging
* Fix and add a limit to the skip-LDS feature
* Enable the skip-LDS feature and fix compilation bugs
* add ck_tile definitions for gfx12
* fix clang format and test/wmma_op
* update instances cmake for gfx12
* disable test_wmma_op on gfx12
* fix the builds for gfx950
* add gfx12 and gfx950 to the default target list
* clean up cmake file
* Initial introduction of OFP8 data types
* Renamed FP8 and BF8 tests to FP8_FNUZ and BF8_FNUZ
* Implementation of ConvertFP32Nearest in test_fp8_ocp
* Remove dependence on a possibly undeclared alias
* Implement FP8OCP test for stochastic rounding mode
* Implement FP8OCP tests for half_t type conversions
* enable bf16 atomic add on gfx950
* Implement ConvertFP32Nearest test
* Implement ConvertFP32Stochastic test
* Implement ConvertFP16Nearest and ConvertFP16Stochastic tests
* Refactoring: move FP8 definitions into a separate header file
* Enable easy switching between architectures
* Fix compilation error for the gfx942 architecture
* only build the gfx950 branch for the gfx950 target by default
* Enable OCP build of example_gemm_xdl_fp8
* Fix formatting
* fix the build logic for gfx950
* Improve GEMM example verbosity
* Add constexpr where applicable
* fix the logic of enabling XDL and WMMA instances
* Improve GEMM example verbosity
* Enable build of example_gemm_xdl_fp8_bf8 test
* Fix tests for the gfx1101 architecture
* Build DPP examples only on gfx103 and gfx11 architectures
* Optionally run either CPU or GPU verification with GEMM examples
* Extend GeneratorTensor_Sequential to produce values of prescribed data types
* Add missing constructor
* Improve infrastructure for OFP8 data type support
* BUGFIX: should not use FP8 as the Compute/Accum data type
* Add custom target for grouped_convnd_bwd_weight tests
* Can build the `tests` target on gfx950
* Bugfixes on the gfx1101 architecture
* Fix dependencies
* Provide a single point of truth for FP8 INF and NAN checks
* Prevent instantiation of operators that are not supported by FP8 data types
* Add FP8 type selection into the client_example CMakeLists.txt
* Prevent the sccache server from shutting down during the build
* Fix test success reporting logic
* Change the default verification method to CPU; GPU verification takes too much time to complete on the emulator
* Make sure all tests and examples are built for gfx950
* Facilitate testing of FP8 data types on the emulator
* Introduce two new tensor generators
* Enable instances built for gfx94 to be built on gfx950
* Verify 35_splitk_gemm on floating-point numbers; splitk gemm appears to lose precision vs. the reference implementation when FP numbers are involved
* Verify 04_gemm_add_add_fastgelu on floating-point numbers
* Verify 20_grouped_conv_bwd_weight on floating-point numbers
* Verify 38_grouped_conv_bwd_data_multiple_d on floating-point numbers
* Verify more tests on floating-point data
* Fix data types and improve testing verbosity
* Upgrade to the NPI 573 build docker
* Skip gemm_universal tests; they take too long to complete on the emulator. Need to see if it is possible to reduce the scope of the testing to just FP8 data types
* Fix gfx1101 build
* Document test availability
* Re-enable fp8 gemms for gfx94/95
* Cherry-pick GEMM Universal tests for FP8 data types
* Cleanup
* CK_USE_GFX94 has already been set on this branch
* Address formatting issues and leftovers
* Make fail/pass logic consistent within the 01_gemm folder: removed multiple negations in fail/pass logic to propagate `true` as the success indicator
* Fix GPU verification reporting logic
* Update year in copyright notice
* Cleanup
* Use `enum class` instead of `enum`
* Remove set_property for FP8 tests
* Narrow the scope of the PR to OCP FP8 enablement only
* Add tests for OCP FP8 vector_type storage
* Enable gemm kernel on all gfx9 architectures (pytorch#227)
* clean-up
* Implement `non_native_vector_base` with `ext_vector_type` array (pytorch#232)
* Enable support of 1-, 2-, 4-, and 8-byte custom types in CK
* Fix pool tests for OCP FP8 data type
* fix jenkins file
* restore cron trigger

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: aska-0096 <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jing Zhang <[email protected]>
Co-authored-by: zjing14 <[email protected]>
Co-authored-by: Jun Liu <[email protected]>
Co-authored-by: Andriy Roshchenko <[email protected]>
Co-authored-by: Andriy Roshchenko <[email protected]>
1 parent 50ee426 commit 08d5c02

File tree

55 files changed: +2509 additions, -384 deletions


CMakeLists.txt

Lines changed: 10 additions & 1 deletion
@@ -185,13 +185,22 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
     add_definitions(-DCK_USE_XDL)
 endif()
 if (SUPPORTED_GPU_TARGETS MATCHES "gfx94")
-    message("Enabling FP8 gemms in ckProfiler")
+    message("Enabling FP8 gemms on native architectures")
     add_definitions(-DCK_USE_GFX94)
 endif()
 if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
     message("Enabling WMMA instances")
     add_definitions(-DCK_USE_WMMA)
 endif()
+if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
+    add_definitions(-DCK_USE_OCP_FP8)
+    set(CK_USE_OCP_FP8 "ON")
+endif()
+if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx94")
+    add_definitions(-DCK_USE_FNUZ_FP8)
+    set(CK_USE_FNUZ_FP8 "ON")
+endif()
+
 option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
 if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
     add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)

client_example/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -56,6 +56,14 @@ if (GPU_TARGETS)
     add_definitions(-DCK_USE_WMMA)
     set(CK_USE_WMMA "ON")
 endif()
+if (GPU_TARGETS MATCHES "gfx12")
+    add_definitions(-DCK_USE_OCP_FP8)
+    set(CK_USE_OCP_FP8 "ON")
+endif()
+if (GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94")
+    add_definitions(-DCK_USE_FNUZ_FP8)
+    set(CK_USE_FNUZ_FP8 "ON")
+endif()
 else()
     add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
     set(CK_USE_XDL "ON")

example/01_gemm/common.hpp

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ struct ProblemSizeSplitK final
 struct ExecutionConfig final
 {
     // 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU
-    int do_verification = 3;
+    int do_verification = 1;
     int init_method     = 2;
     bool time_kernel    = false;
 };

example/01_gemm/run_gemm_example.inc

Lines changed: 2 additions & 2 deletions
@@ -143,8 +143,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     switch(config.init_method)
     {
     case 0:
-        ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
-        ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
+        ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.f)}(a_m_k);
+        ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.f)}(b_k_n);
         break;
     case 1:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);

example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp

Lines changed: 4 additions & 4 deletions
@@ -186,15 +186,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         for(int j = 0; j < NumDMatrices; ++j)
         {
-            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
         }
         break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
         for(int j = 0; j < NumDMatrices; ++j)
         {
-            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
         }
     }
 }

example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp

Lines changed: 4 additions & 4 deletions
@@ -190,15 +190,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         for(int j = 0; j < NumDs; ++j)
         {
-            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
         }
         break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
         for(int j = 0; j < NumDs; ++j)
         {
-            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+            d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
         }
     }
 }

example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp

Lines changed: 3 additions & 3 deletions
@@ -167,11 +167,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }

-    d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+    d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<D0DataType, 1>{});
 }

 using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>;

example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp

Lines changed: 2 additions & 2 deletions
@@ -157,8 +157,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }
 }

example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp

Lines changed: 2 additions & 2 deletions
@@ -158,8 +158,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }
 }

example/15_grouped_gemm/run_grouped_gemm_example.inc

Lines changed: 5 additions & 2 deletions
@@ -1,3 +1,6 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once

 struct ProblemSize final
@@ -124,8 +127,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }
 }

example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <iostream>
 #include <numeric>
@@ -175,8 +175,8 @@ int main(int argc, char* argv[])
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }

     c0_n_bias.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});

example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -150,7 +150,7 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
         break;
     default:
         a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }

example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 int run(int argc, char* argv[])
 {
@@ -157,7 +157,7 @@ int run(int argc, char* argv[])
         break;
     default:
         a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }

example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 int run(int argc, char* argv[])
 {
@@ -118,7 +118,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }

example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 int run(int argc, char* argv[])
 {
@@ -153,7 +153,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }

example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 int run(int argc, char* argv[])
 {
@@ -178,7 +178,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }

example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 int run(int argc, char* argv[])
 {
@@ -152,7 +152,7 @@ int run(int argc, char* argv[])
         break;
     default:
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }

example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 int run(int argc, char* argv[])
 {
@@ -156,7 +156,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }

example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 int run(int argc, char* argv[])
 {
@@ -156,7 +156,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }

example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 int run(int argc, char* argv[])
 {
@@ -173,7 +173,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }

example/35_splitK_gemm/run_splitK_gemm_example.inc

Lines changed: 5 additions & 2 deletions
@@ -1,3 +1,6 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once

 struct ProblemSize final
@@ -66,8 +69,8 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }

     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());

example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp

Lines changed: 1 addition & 1 deletion
@@ -377,7 +377,7 @@ int main(int argc, char* argv[])
         break;
     default:
         a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
         d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
         b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});

example/38_grouped_conv_bwd_data_multiple_d/common.hpp

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -41,7 +41,7 @@ struct ExecutionConfig final
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = true;
+    bool time_kernel     = false;
 };

 #define DefaultConvParams \

example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <iostream>
 #include <vector>
@@ -248,7 +248,7 @@ int main(int argc, char* argv[])
         d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
         d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});

example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <iostream>
 #include <numeric>
@@ -194,9 +194,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b1_tensors[i].GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
         break;
     default:
-        a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
-        b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
+        b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
+        b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B1DataType, 1>{});
     }

     d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
