Skip to content

Commit fe6ce55

Browse files
authored
Grouped gemm test fix (pytorch#150)
* Fixed test: return res; use random GEMM shapes
* Fixed return
1 parent 313bbea commit fe6ce55

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

test/grouped_gemm/grouped_gemm_fp16.cpp

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ static bool check_err(const Tensor<T>& ref, const Tensor<T>& result)
6666

6767
bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
6868
{
69-
int group_count = 4;
69+
int group_count = rand() % 10 + 1;
7070

7171
// GEMM shape
7272
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
@@ -77,9 +77,9 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
7777

7878
for(int i = 0; i < group_count; i++)
7979
{
80-
int M = 256 + 256 * i;
81-
int N = 128 + 128 * i;
82-
int K = 128 + 64 * i;
80+
int M = 256 + 256 * (rand() % 10);
81+
int N = 256 + 256 * (rand() % 10);
82+
int K = 128 + 128 * (rand() % 10);
8383

8484
int AStride = std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? K : M;
8585
int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K;
@@ -132,8 +132,8 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
132132
c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
133133
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));
134134

135-
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
136-
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
135+
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
136+
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
137137
}
138138

139139
for(int i = 0; i < gemm_shapes.size(); i++)
@@ -181,6 +181,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
181181
b_element_op,
182182
c_element_op);
183183

184+
if(!groupedGemmPtr->IsSupportedArgument(argument_ptr.get()))
185+
{
186+
return false;
187+
}
188+
184189
ref_invoker.Run(ref_argument);
185190

186191
bool res = check_err(c_device_tensors[i], c_host_tensors[i]);
@@ -210,4 +215,6 @@ int main()
210215
}
211216

212217
std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
218+
219+
return res ? 0 : 1;
213220
}

0 commit comments

Comments
 (0)