39
39
40
40
@dataclass
41
41
class VkTestSuite (TestSuite ):
42
- supports = {
43
- "storage_types" : ["api::StorageType::TEXTURE_3D" ],
44
- "layouts" : [
45
- "api::GPUMemoryLayout::TENSOR_WIDTH_PACKED" ,
46
- "api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED" ,
47
- ],
48
- }
42
+ def __init__ (self , input_cases : List [Any ]):
43
+ super ().__init__ (input_cases )
44
+ self .storage_types : List [str ] = ["api::kTexture3D" ]
45
+ self .layouts : List [str ] = ["api::kChannelsPacked" ]
49
46
50
47
51
48
##########################
@@ -88,7 +85,6 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
88
85
self .dot = "->"
89
86
90
87
self .args = []
91
- self .out = None
92
88
self .refs = {}
93
89
94
90
self .should_prepack = False
@@ -288,6 +284,7 @@ def set_output(self, ref: ValueRefList) -> str:
288
284
return ret_str
289
285
290
286
def virtual_resize (self , ref : ValueRefList ) -> str :
287
+ assert isinstance (ref , ValueRef )
291
288
assert ref .src_cpp_type == AT_TENSOR and ref .is_in
292
289
if self .prepack_ref (ref ):
293
290
return ""
@@ -296,6 +293,7 @@ def virtual_resize(self, ref: ValueRefList) -> str:
296
293
return ret_str
297
294
298
295
def copy_into_staging (self , ref : ValueRefList ) -> str :
296
+ assert isinstance (ref , ValueRef )
299
297
assert ref .src_cpp_type == AT_TENSOR and ref .is_in
300
298
if self .prepack_ref (ref ):
301
299
return ""
@@ -336,7 +334,7 @@ def check_graph_out(self, ref: ValueRefList) -> str:
336
334
ret_str += self .check_graph_out (r )
337
335
return ret_str
338
336
339
- return f"EXPECT_TRUE(check_close({ ref .src_cpp_name } , vk_{ ref .name } ));\n "
337
+ return f"EXPECT_TRUE(check_close({ ref .src_cpp_name } , vk_{ ref .name } , rtol, atol ));\n "
340
338
341
339
## Top level code generation
342
340
@@ -374,11 +372,19 @@ def gen_graph_exec_code(self) -> str:
374
372
375
373
return graph_exec
376
374
375
+ def gen_conditional_skips (self ) -> str :
376
+ skips = "if (test_dtype == at::kHalf && "
377
+ skips += f"!{ self .graph } { self .dot } context()->adapter_ptr()->has_16bit_storage()) {{\n "
378
+ skips += " GTEST_SKIP();"
379
+ skips += "}\n "
380
+ return skips
381
+
377
382
def gen_op_check_fn (self ) -> str :
378
383
op_name = self .f .func .name .unambiguous_name ()
379
384
op_check_fn = self .gen_decl (f"check_{ op_name } " ) + " {"
380
385
if self .should_prepack :
381
386
op_check_fn = self .gen_decl (f"prepacked_check_{ op_name } " ) + " {"
387
+ op_check_fn += self .gen_conditional_skips ()
382
388
op_check_fn += self .gen_graph_build_code ()
383
389
op_check_fn += self .gen_graph_exec_code ()
384
390
op_check_fn += self .check_graph_out (self .refs ["out" ])
@@ -391,19 +397,26 @@ def gen_op_check_fn(self) -> str:
391
397
##################################
392
398
393
399
test_fixture_template = """
394
- class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<api::StorageType, api::GPUMemoryLayout>> {{
400
+ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<at::ScalarType, api::StorageType, api::GPUMemoryLayout>> {{
395
401
protected:
396
402
ComputeGraph* graph;
397
403
at::ScalarType test_dtype = at::kFloat;
404
+ float rtol = 1e-5;
405
+ float atol = 1e-5;
398
406
399
407
void SetUp() override {{
400
408
GraphConfig config;
401
409
api::StorageType default_storage_type;
402
410
api::GPUMemoryLayout default_memory_layout;
403
- std::tie(default_storage_type, default_memory_layout) = GetParam();
411
+ std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
404
412
config.setStorageTypeOverride(default_storage_type);
405
413
config.setMemoryLayoutOverride(default_memory_layout);
406
414
graph = new ComputeGraph(config);
415
+
416
+ if (test_dtype == at::kHalf) {{
417
+ rtol = 1e-2;
418
+ atol = 1e-2;
419
+ }}
407
420
}}
408
421
409
422
void TearDown() override {{
@@ -420,7 +433,7 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple
420
433
421
434
422
435
class VkTestSuiteGen (TestSuiteGen ):
423
- def __init__ (self , op_reg_name : str , f : NativeFunction , inputs : List [ Any ] ):
436
+ def __init__ (self , op_reg_name : str , f : NativeFunction , inputs : VkTestSuite ):
424
437
super ().__init__ (f , inputs )
425
438
self .op_reg_name = op_reg_name
426
439
self .generator = ComputeGraphGen (self .op_reg_name , self .f , self .suite_def )
@@ -442,14 +455,16 @@ def generate_fixture_cpp(self) -> str:
442
455
)
443
456
444
457
def gen_parameterization (self ) -> str :
445
- storage_types = self .suite_def .supports ["storage_types" ]
446
- layouts = self .suite_def .supports ["layouts" ]
458
+ dtypes = self .suite_def .dtypes
459
+ storage_types = self .suite_def .storage_types
460
+ layouts = self .suite_def .layouts
447
461
448
462
return f"""
449
463
INSTANTIATE_TEST_SUITE_P(
450
- StorageLayoutCombos_ { self .op_name } ,
464
+ Combos_ { self .op_name } ,
451
465
GeneratedOpsTest_{ self .op_name } ,
452
466
::testing::Combine(
467
+ ::testing::Values({ ', ' .join (dtypes )} ),
453
468
::testing::Values({ ', ' .join (storage_types )} ),
454
469
::testing::Values({ ', ' .join (layouts )} )));
455
470
"""
@@ -494,9 +509,11 @@ def gen_parameterization(self) -> str:
494
509
return true;
495
510
}
496
511
bool is_close = at::allclose(t1, t2, rtol, atol);
497
- if (!is_close) {
498
- std::cout << "t1:" << t1 << std::endl;
499
- std::cout << "t2:" << t2 << std::endl;
512
+ if (!is_close && t1.numel() < 500) {
513
+ std::cout << "reference: " << std::endl;
514
+ print(t1, 150);
515
+ std::cout << "vulkan: " << std::endl;
516
+ print(t2, 150);
500
517
}
501
518
return is_close;
502
519
}
0 commit comments