Skip to content

Commit 333860c

Browse files
authored
Reduce problem size of unbatched_gemm tests (#1383)
The verify tests from pr #1354 were still causing some codecov timeouts after merge. This PR further reduces the problem sizes to avoid these failures.
1 parent 4b76dd0 commit 333860c

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

test/verify/test_unbatched_gemm_1.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,18 @@ struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
3333
{
3434
migraphx::program p;
3535
auto* mm = p.get_main_module();
36-
migraphx::shape m1_shape{migraphx::shape::float_type, {4, 384, 768}};
37-
migraphx::shape m2_shape{migraphx::shape::float_type, {768, 768}};
38-
migraphx::shape m3_shape{migraphx::shape::float_type, {4, 384, 2304}};
36+
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 32, 64}};
37+
migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}};
38+
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 32, 192}};
3939
auto l1 = mm->add_parameter("1", m1_shape);
4040
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
41-
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 768, 768}}}),
41+
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
4242
l2);
4343
auto l3 = mm->add_literal(migraphx::generate_literal(m2_shape));
44-
l3 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 768, 768}}}),
44+
l3 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
4545
l3);
4646
auto l4 = mm->add_literal(migraphx::generate_literal(m2_shape));
47-
l4 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 768, 768}}}),
47+
l4 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
4848
l4);
4949
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), l2, l3, l4);
5050

test/verify/test_unbatched_gemm_2.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
3333
{
3434
migraphx::program p;
3535
auto* mm = p.get_main_module();
36-
migraphx::shape m1_shape{migraphx::shape::float_type, {4, 384, 768}};
37-
migraphx::shape m2_shape{migraphx::shape::float_type, {768, 768}};
36+
migraphx::shape m1_shape{migraphx::shape::float_type, {4, 32, 64}};
37+
migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}};
3838
auto l1 = mm->add_parameter("1", m1_shape);
3939
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
40-
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 768, 768}}}),
40+
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 64, 64}}}),
4141
l2);
4242

4343
mm->add_instruction(migraphx::make_op("dot"), l1, l2);

0 commit comments

Comments
 (0)