@@ -66,7 +66,7 @@ static bool check_err(const Tensor<T>& ref, const Tensor<T>& result)
66
66
67
67
bool TestGroupedGemm (DeviceGroupedGemmPtr_& groupedGemmPtr)
68
68
{
69
- int group_count = 4 ;
69
+ int group_count = rand () % 10 + 1 ;
70
70
71
71
// GEMM shape
72
72
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
@@ -77,9 +77,9 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
77
77
78
78
for (int i = 0 ; i < group_count; i++)
79
79
{
80
- int M = 256 + 256 * i ;
81
- int N = 128 + 128 * i ;
82
- int K = 128 + 64 * i ;
80
+ int M = 256 + 256 * ( rand () % 10 ) ;
81
+ int N = 256 + 256 * ( rand () % 10 ) ;
82
+ int K = 128 + 128 * ( rand () % 10 ) ;
83
83
84
84
int AStride = std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? K : M;
85
85
int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K;
@@ -132,8 +132,8 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
132
132
c_device_tensors.emplace_back (Tensor<CDataType>(f_host_tensor_descriptor (
133
133
gemm_shapes[i].M , gemm_shapes[i].N , gemm_shapes[i].StrideC , CLayout{})));
134
134
135
- a_tensors[i].GenerateTensorValue (GeneratorTensor_3 <ADataType>{0.0 , 1.0 });
136
- b_tensors[i].GenerateTensorValue (GeneratorTensor_3 <BDataType>{-0. 5 , 0. 5 });
135
+ a_tensors[i].GenerateTensorValue (GeneratorTensor_2 <ADataType>{- 5 , 5 });
136
+ b_tensors[i].GenerateTensorValue (GeneratorTensor_2 <BDataType>{-5 , 5 });
137
137
}
138
138
139
139
for (int i = 0 ; i < gemm_shapes.size (); i++)
@@ -181,6 +181,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
181
181
b_element_op,
182
182
c_element_op);
183
183
184
+ if (!groupedGemmPtr->IsSupportedArgument (argument_ptr.get ()))
185
+ {
186
+ return false ;
187
+ }
188
+
184
189
ref_invoker.Run (ref_argument);
185
190
186
191
bool res = check_err (c_device_tensors[i], c_host_tensors[i]);
@@ -210,4 +215,6 @@ int main()
210
215
}
211
216
212
217
std::cout << " TestGroupedGemm ..... " << (res ? " SUCCESS" : " FAILURE" ) << std::endl;
218
+
219
+ return res ? 0 : 1 ;
213
220
}
0 commit comments