@@ -48,6 +48,26 @@ inline std::vector<std::vector<std::size_t>> to_lens(const std::vector<migraphx:
48
48
return result;
49
49
}
50
50
51
+ migraphx::module make_concat_multibroadcast (const std::vector<size_t >& in_lens,
52
+ const std::vector<size_t >& mbcast_lens,
53
+ const int axis)
54
+ {
55
+ migraphx::module m;
56
+ auto s = migraphx::shape{migraphx::shape::float_type, in_lens};
57
+ auto x = m.add_parameter (" x" , s);
58
+ auto y = m.add_parameter (" y" , s);
59
+ auto z = m.add_parameter (" z" , s);
60
+ auto xm =
61
+ m.add_instruction (migraphx::make_op (" multibroadcast" , {{" out_lens" , mbcast_lens}}), x);
62
+ auto ym =
63
+ m.add_instruction (migraphx::make_op (" multibroadcast" , {{" out_lens" , mbcast_lens}}), y);
64
+ auto zm =
65
+ m.add_instruction (migraphx::make_op (" multibroadcast" , {{" out_lens" , mbcast_lens}}), z);
66
+ auto concat = m.add_instruction (migraphx::make_op (" concat" , {{" axis" , axis}}), xm, ym, zm);
67
+ m.add_return ({concat});
68
+ return m;
69
+ }
70
+
51
71
TEST_CASE (double_contig)
52
72
{
53
73
migraphx::program p;
@@ -337,6 +357,87 @@ TEST_CASE(nop_convert)
337
357
EXPECT (std::distance (m.begin (), m.end ()) == n - 1 );
338
358
}
339
359
360
+ TEST_CASE (concat_multibroadcasts1)
361
+ {
362
+ // Broadcasted batch dim, new axis < old axis
363
+ std::vector<std::size_t > in_lens = {3 , 4 };
364
+ std::vector<std::size_t > mbcast_lens = {2 , 3 , 4 };
365
+ const int axis = 2 ;
366
+ auto m = make_concat_multibroadcast (in_lens, mbcast_lens, axis);
367
+ auto out_shape = m.get_output_shapes ().back ();
368
+ auto n = std::distance (m.begin (), m.end ());
369
+ run_pass (m);
370
+ EXPECT (m.get_output_shapes ().back ().lens () == out_shape.lens ());
371
+ EXPECT (std::distance (m.begin (), m.end ()) == n - 2 );
372
+ auto new_concat =
373
+ std::find_if (m.begin (), m.end (), [](auto ins) { return ins.name () == " concat" ; });
374
+ EXPECT (bool {new_concat != m.end ()});
375
+ auto cd = std::distance (m.begin (), new_concat);
376
+ auto new_mb =
377
+ std::find_if (m.begin (), m.end (), [](auto ins) { return ins.name () == " multibroadcast" ; });
378
+ auto md = std::distance (m.begin (), new_mb);
379
+ EXPECT (cd == md - 1 );
380
+ EXPECT (migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator ()).axis == 1 );
381
+ }
382
+
383
+ TEST_CASE (concat_multibroadcasts2)
384
+ {
385
+ // Broadcasted middle dim, new axis == old axis
386
+ std::vector<std::size_t > in_lens = {3 , 1 , 4 };
387
+ std::vector<std::size_t > mbcast_lens = {3 , 2 , 4 };
388
+ const int axis = 0 ;
389
+ auto m = make_concat_multibroadcast (in_lens, mbcast_lens, axis);
390
+ auto out_shape = m.get_output_shapes ().back ();
391
+ auto n = std::distance (m.begin (), m.end ());
392
+ run_pass (m);
393
+ EXPECT (m.get_output_shapes ().back ().lens () == out_shape.lens ());
394
+ EXPECT (std::distance (m.begin (), m.end ()) == n - 2 );
395
+ auto new_concat =
396
+ std::find_if (m.begin (), m.end (), [](auto ins) { return ins.name () == " concat" ; });
397
+ EXPECT (bool {new_concat != m.end ()});
398
+ auto cd = std::distance (m.begin (), new_concat);
399
+ auto new_mb =
400
+ std::find_if (m.begin (), m.end (), [](auto ins) { return ins.name () == " multibroadcast" ; });
401
+ auto md = std::distance (m.begin (), new_mb);
402
+ EXPECT (cd == md - 1 );
403
+ EXPECT (migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator ()).axis == 0 );
404
+ }
405
+
406
+ TEST_CASE (concat_multibroadcasts3)
407
+ {
408
+ // Broadcasted middle dim, new axis == old axis
409
+ std::vector<std::size_t > in_lens = {3 , 1 , 4 };
410
+ std::vector<std::size_t > mbcast_lens = {3 , 2 , 4 };
411
+ const int axis = 2 ;
412
+ auto m = make_concat_multibroadcast (in_lens, mbcast_lens, axis);
413
+ auto out_shape = m.get_output_shapes ().back ();
414
+ auto n = std::distance (m.begin (), m.end ());
415
+ run_pass (m);
416
+ EXPECT (m.get_output_shapes ().back ().lens () == out_shape.lens ());
417
+ EXPECT (std::distance (m.begin (), m.end ()) == n - 2 );
418
+ auto new_concat =
419
+ std::find_if (m.begin (), m.end (), [](auto ins) { return ins.name () == " concat" ; });
420
+ EXPECT (bool {new_concat != m.end ()});
421
+ auto cd = std::distance (m.begin (), new_concat);
422
+ auto new_mb =
423
+ std::find_if (m.begin (), m.end (), [](auto ins) { return ins.name () == " multibroadcast" ; });
424
+ auto md = std::distance (m.begin (), new_mb);
425
+ EXPECT (cd == md - 1 );
426
+ EXPECT (migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator ()).axis == 2 );
427
+ }
428
+
429
+ TEST_CASE (concat_multibroadcasts4)
430
+ {
431
+ // Broadcasted batch dim, axis is broadcasted dim
432
+ std::vector<std::size_t > in_lens = {3 , 4 };
433
+ std::vector<std::size_t > mbcast_lens = {2 , 3 , 4 };
434
+ const int axis = 0 ;
435
+ auto m = make_concat_multibroadcast (in_lens, mbcast_lens, axis);
436
+ auto m1 = m;
437
+ run_pass (m);
438
+ EXPECT (m1 == m);
439
+ }
440
+
340
441
TEST_CASE (concat_transpose1)
341
442
{
342
443
migraphx::module m;
0 commit comments