Skip to content

Commit 9f71ff4

Browse files
rosenrodtChao Liu
and
Chao Liu
authored
Validate examples in CI (pytorch#233)
* validate examples in ctest runs * format * fix usage of check_err * amend * add example codes to custom target 'check' Co-authored-by: Chao Liu <[email protected]>
1 parent cec69bc commit 9f71ff4

26 files changed

+125
-62
lines changed

CMakeLists.txt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ if(BUILD_DEV)
245245
endif()
246246
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
247247

248+
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
249+
248250
add_subdirectory(library)
249251
add_subdirectory(example)
250252
add_subdirectory(test)
@@ -260,14 +262,14 @@ write_basic_package_version_file(
260262
COMPATIBILITY AnyNewerVersion
261263
)
262264

263-
configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in
265+
configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in
264266
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
265-
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
267+
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
266268
NO_CHECK_REQUIRED_COMPONENTS_MACRO
267269
)
268270

269-
install(FILES
271+
install(FILES
270272
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
271273
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
272-
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
274+
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
273275
)

example/01_gemm/gemm_xdl_bf16.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ int main(int argc, char* argv[])
232232

233233
ref_invoker.Run(ref_argument);
234234

235-
ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData);
235+
return ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData) ? 0 : 1;
236236
}
237237

238238
return 0;

example/01_gemm/gemm_xdl_fp16.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ int main(int argc, char* argv[])
196196

197197
ref_invoker.Run(ref_argument);
198198

199-
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
199+
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
200200
}
201201

202202
return 0;

example/01_gemm/gemm_xdl_int8.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ int main(int argc, char* argv[])
219219

220220
ref_invoker.Run(ref_argument);
221221

222-
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
222+
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
223223
}
224224

225225
return 0;

example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ int main(int argc, char* argv[])
246246

247247
ref_invoker.Run(ref_argument);
248248

249-
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
249+
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
250250
}
251+
252+
return 0;
251253
}

example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ int main(int argc, char* argv[])
232232

233233
ref_invoker.Run(ref_argument);
234234

235-
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
235+
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
236236
}
237+
238+
return 0;
237239
}

example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,8 @@ int main(int argc, char* argv[])
250250

251251
ref_invoker.Run(ref_argument);
252252

253-
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
253+
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
254254
}
255+
256+
return 0;
255257
}

example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,8 @@ int main(int argc, char* argv[])
305305
OutElementOp{});
306306
ref_invoker.Run(ref_argument);
307307
out_device_buf.FromDevice(device_output.mData.data());
308-
ck::utils::check_err(
309-
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
308+
return ck::utils::check_err(device_output.mData, host_output.mData) ? 0 : 1;
310309
}
310+
311+
return 0;
311312
}
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
add_example_executable(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp)
1+
# FIXME: should fix validation failure
2+
add_example_executable_no_testing(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp)
23
target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_util)

example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,8 @@ int main(int argc, char* argv[])
320320

321321
ref_invoker.Run(ref_argument);
322322
out_device_buf.FromDevice(device_output.mData.data());
323-
ck::utils::check_err(
324-
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
323+
return ck::utils::check_err(device_output.mData, host_output.mData) ? 0 : 1;
325324
}
325+
326+
return 0;
326327
}

example/09_convnd_fwd/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
add_example_executable(example_convnd_fwd_xdl convnd_fwd_xdl.cpp)
2-
target_link_libraries(example_convnd_fwd_xdl PRIVATE conv_util)
1+
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
32
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
4-
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util)
53
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
4+
target_link_libraries(example_convnd_fwd_xdl_fp32 PRIVATE conv_util)
5+
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util)
66
target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_util)

example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ template <ck::index_t NumDimSpatial>
4343
using DeviceConvNDFwdInstance = ck::tensor_operation::device::
4444
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
4545
// clang-format off
46-
InDataType, //
46+
InDataType, //
4747
WeiDataType, //
4848
OutDataType, //
49-
AccDataType, //
49+
AccDataType, //
5050
InElementOp, // Input Elementwise Operation
5151
WeiElementOp, // Weights Elementwise Operation
5252
OutElementOp, // Output Elementwise Operation
@@ -312,8 +312,8 @@ int main(int argc, char* argv[])
312312

313313
ref_invoker.Run(ref_argument);
314314
out_device_buf.FromDevice(device_output.mData.data());
315-
ck::utils::check_err(
316-
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
315+
return ck::utils::check_err(
316+
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f) ? 0 : 1;
317317
};
318318

319319
switch(num_dim_spatial)
@@ -338,4 +338,5 @@ int main(int argc, char* argv[])
338338
}
339339
}
340340
}
341+
return 0;
341342
}

example/09_convnd_fwd/convnd_fwd_xdl.cpp renamed to example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ template <ck::index_t NumDimSpatial>
3939
using DeviceConvNDFwdInstance = ck::tensor_operation::device::
4040
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
4141
// clang-format off
42-
InDataType, //
42+
InDataType, //
4343
WeiDataType, //
4444
OutDataType, //
45-
AccDataType, //
45+
AccDataType, //
4646
InElementOp, // Input Elementwise Operation
4747
WeiElementOp, // Weights Elementwise Operation
4848
OutElementOp, // Output Elementwise Operation
@@ -311,8 +311,13 @@ int main(int argc, char* argv[])
311311

312312
ref_invoker.Run(ref_argument);
313313
out_device_buf.FromDevice(device_output.mData.data());
314-
ck::utils::check_err(
315-
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
314+
return ck::utils::check_err(device_output.mData,
315+
host_output.mData,
316+
"Error: incorrect results!",
317+
1e-5f,
318+
1e-4f)
319+
? 0
320+
: 1;
316321
};
317322

318323
switch(num_dim_spatial)
@@ -337,4 +342,5 @@ int main(int argc, char* argv[])
337342
}
338343
}
339344
}
345+
return 0;
340346
}

example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ template <ck::index_t NumDimSpatial>
4545
using DeviceConvNDFwdInstance = ck::tensor_operation::device::
4646
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
4747
// clang-format off
48-
InDataType, //
48+
InDataType, //
4949
WeiDataType, //
5050
OutDataType, //
51-
AccDataType, //
51+
AccDataType, //
5252
InElementOp, // Input Elementwise Operation
5353
WeiElementOp, // Weights Elementwise Operation
5454
OutElementOp, // Output Elementwise Operation
@@ -314,8 +314,8 @@ int main(int argc, char* argv[])
314314

315315
ref_invoker.Run(ref_argument);
316316
out_device_buf.FromDevice(device_output.mData.data());
317-
ck::utils::check_err(
318-
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
317+
return ck::utils::check_err(
318+
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f) ? 0 : 1;
319319
};
320320

321321
switch(num_dim_spatial)
@@ -340,4 +340,5 @@ int main(int argc, char* argv[])
340340
}
341341
}
342342
}
343+
return 0;
343344
}

example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,10 @@ int main(int argc, char* argv[])
249249

250250
in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
251251

252-
ck::utils::check_err(in_n_c_hi_wi_device_result.mData, in_n_c_hi_wi_host_result.mData);
252+
return ck::utils::check_err(in_n_c_hi_wi_device_result.mData,
253+
in_n_c_hi_wi_host_result.mData)
254+
? 0
255+
: 1;
253256
}
257+
return 0;
254258
}

example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,9 @@ int main(int argc, char* argv[])
291291
LogRangeAsType<float>(std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
292292
<< std::endl;
293293
}
294-
ck::utils::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData);
294+
return ck::utils::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData)
295+
? 0
296+
: 1;
295297
}
298+
return 0;
296299
}

example/12_reduce/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
add_example_executable(example_reduce_blockwise reduce_blockwise.cpp)
1+
add_example_executable(example_reduce_blockwise reduce_blockwise.cpp -D 16,64,32,960 -v 1 1 10)

example/12_reduce/reduce_blockwise.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,16 +361,17 @@ int main(int argc, char* argv[])
361361
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name
362362
<< std::endl;
363363

364+
bool pass = true;
364365
if(args.do_verification)
365366
{
366367
out_dev.FromDevice(out.mData.data());
367-
ck::utils::check_err(out.mData, out_ref.mData);
368+
pass &= ck::utils::check_err(out.mData, out_ref.mData);
368369

369370
if(NeedIndices)
370371
{
371372
out_indices_dev.FromDevice(out_indices.mData.data());
372-
ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
373-
;
373+
pass &= ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
374374
};
375375
};
376+
return pass ? 0 : 1;
376377
}

example/13_pool2d_fwd/pool2d_fwd.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ int main(int argc, char* argv[])
285285
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
286286
<< std::endl;
287287

288+
bool pass = true;
288289
if(do_verification)
289290
{
290291
pool_host_verify<InDataType,
@@ -302,14 +303,15 @@ int main(int argc, char* argv[])
302303

303304
out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data());
304305

305-
ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData);
306+
pass &= ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData);
306307

307308
if constexpr(NeedIndices)
308309
{
309310
out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data());
310311

311-
// ck::utils::check_err(out_indices_n_c_ho_wo_device.mData,
312-
// out_indices_n_c_ho_wo_host.mData);;
312+
pass &= ck::utils::check_err(out_indices_n_c_ho_wo_device.mData,
313+
out_indices_n_c_ho_wo_host.mData);
313314
};
314315
}
316+
return pass ? 0 : 1;
315317
}

example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ int main(int argc, char* argv[])
244244

245245
ref_invoker.Run(ref_argument);
246246

247-
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
247+
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
248248
}
249249

250250
return 0;

example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ int main(int argc, char* argv[])
211211
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
212212
<< gemm.GetTypeString() << std::endl;
213213

214+
bool pass = true;
214215
if(do_verification)
215216
{
216217
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
@@ -227,9 +228,9 @@ int main(int argc, char* argv[])
227228
c_element_op);
228229

229230
ref_invoker.Run(ref_argument);
230-
ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
231+
pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
231232
}
232233
}
233234

234-
return 0;
235+
return pass ? 0 : 1;
235236
}

example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <cstdlib>
55
#include <stdlib.h>
66
#include <half.hpp>
7+
#include "check_err.hpp"
78
#include "config.hpp"
89
#include "device.hpp"
910
#include "host_tensor.hpp"
@@ -211,6 +212,7 @@ int main(int argc, char* argv[])
211212
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
212213
<< gemm.GetTypeString() << std::endl;
213214

215+
bool pass = true;
214216
if(do_verification)
215217
{
216218
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
@@ -247,10 +249,19 @@ int main(int argc, char* argv[])
247249
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
248250
}
249251

250-
check_error(c_m_n_host_result, c_m_n_device_result);
251-
check_error(d0_m_host_result, d0_m_device_result);
252-
check_error(d1_m_host_result, d1_m_device_result);
252+
pass &= ck::utils::check_err(
253+
c_m_n_device_result.mData, c_m_n_host_result.mData, "Error: Incorrect results c");
254+
pass &= ck::utils::check_err(d0_m_device_result.mData,
255+
d0_m_host_result.mData,
256+
"Error: Incorrect results d0",
257+
1e-3,
258+
1e-3);
259+
pass &= ck::utils::check_err(d1_m_device_result.mData,
260+
d1_m_host_result.mData,
261+
"Error: Incorrect results d1",
262+
1e-3,
263+
1e-3);
253264
}
254265

255-
return 0;
266+
return pass ? 0 : 1;
256267
}

example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,10 @@ int main(int argc, char* argv[])
322322

323323
in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
324324

325-
check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result);
325+
return ck::utils::check_err(in_n_c_hi_wi_device_result.mData,
326+
in_n_c_hi_wi_host_result.mData)
327+
? 0
328+
: 1;
326329
};
327330

328331
switch(num_dim_spatial)
@@ -347,4 +350,5 @@ int main(int argc, char* argv[])
347350
}
348351
}
349352
}
353+
return 0;
350354
}

0 commit comments

Comments
 (0)