Skip to content

Commit a10a8ef

Browse files
authored
Use rocblas_gemm_ex for batched gemms with broadcasted B (#1354)
Improves performance for four of the six GEMMs used by Hugging Face BERT models when batch_size > 1, by issuing a single non-batched rocBLAS call for GEMMs whose B input has a broadcasted batch dimension. The four verify tests added reflect the actual configurations used by bert-base-cased, with varied batch sizes. Also adds a matcher to simplify_reshapes that moves multibroadcasts after concats.
1 parent d78bcdf commit a10a8ef

File tree

5 files changed

+250
-1
lines changed

5 files changed

+250
-1
lines changed

src/simplify_reshapes.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,44 @@ struct find_nested_slice
271271
}
272272
};
273273

274+
struct find_concat_multibroadcasts
275+
{
276+
auto matcher() const
277+
{
278+
return match::name("concat")(match::all_of[match::inputs()](match::name("multibroadcast")));
279+
}
280+
281+
void apply(module& m, const match::matcher_result& mr) const
282+
{
283+
auto ins = mr.result;
284+
auto op = any_cast<op::concat>(ins->get_operator());
285+
auto out_lens = ins->get_shape().lens();
286+
auto inputs = ins->inputs();
287+
auto in_strides = inputs.front()->get_shape().strides();
288+
289+
// Only apply when concat axis is not a broadcasted dimension
290+
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
291+
return i->get_shape().strides()[op.axis] == 0;
292+
}))
293+
{
294+
return;
295+
}
296+
297+
// Use inputs of multibroadcast ops as inputs to new concat op
298+
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) {
299+
return i->inputs().front();
300+
});
301+
302+
// Reduce axis by number of leading broadcasted dimensions
303+
if(inputs.front()->get_shape().lens().size() < out_lens.size())
304+
op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0);
305+
306+
auto concat = m.insert_instruction(ins, op, inputs);
307+
m.replace_instruction(
308+
ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat);
309+
}
310+
};
311+
274312
struct find_concat_transpose
275313
{
276314
auto matcher() const
@@ -764,6 +802,7 @@ void simplify_reshapes::apply(module& m) const
764802
find_reshaper{},
765803
find_transpose{},
766804
find_concat_transpose{},
805+
find_concat_multibroadcasts{},
767806
find_nested_convert{},
768807
find_nested_slice{},
769808
find_nested_concat{},

src/targets/gpu/gemm_impl.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,13 @@ void gemm_impl(context& ctx,
176176

177177
auto num_matrices = std::accumulate(
178178
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
179-
if(num_matrices == 1)
179+
if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(args[1]) == 0))
180180
{
181+
// If the batch dimension of B is broadcasted, then we can
182+
// multiply m by the batch_size and use rocblas_gemm_ex
183+
// instead of rocblas_gemm_strided_batched_ex.
184+
m *= num_matrices;
185+
181186
// the rocblas_gemm API handles inputs and output matrices as
182187
// column-major format. When doing a C = A * B, we actually do
183188
// C^T = (B^T) * (A^T). That is the reason we input args[1] as

test/simplify_reshapes_test.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,26 @@ inline std::vector<std::vector<std::size_t>> to_lens(const std::vector<migraphx:
4848
return result;
4949
}
5050

51+
// Builds a module with three float parameters "x", "y" and "z" of shape
// `in_lens`, multibroadcasts each of them to `mbcast_lens`, concatenates the
// three broadcasts along `axis`, and returns the concat result.
migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
                                            const std::vector<size_t>& mbcast_lens,
                                            const int axis)
{
    migraphx::module mod;
    const migraphx::shape param_shape{migraphx::shape::float_type, in_lens};

    // Parameters are added first, then the broadcasts, in x/y/z order.
    auto px = mod.add_parameter("x", param_shape);
    auto py = mod.add_parameter("y", param_shape);
    auto pz = mod.add_parameter("z", param_shape);

    // The same broadcast operator is applied to every parameter.
    const auto bcast = migraphx::make_op("multibroadcast", {{"out_lens", mbcast_lens}});
    auto bx          = mod.add_instruction(bcast, px);
    auto by          = mod.add_instruction(bcast, py);
    auto bz          = mod.add_instruction(bcast, pz);

    auto joined = mod.add_instruction(migraphx::make_op("concat", {{"axis", axis}}), bx, by, bz);
    mod.add_return({joined});
    return mod;
}
70+
5171
TEST_CASE(double_contig)
5272
{
5373
migraphx::program p;
@@ -337,6 +357,87 @@ TEST_CASE(nop_convert)
337357
EXPECT(std::distance(m.begin(), m.end()) == n - 1);
338358
}
339359

360+
TEST_CASE(concat_multibroadcasts1)
{
    // The batch dim is introduced by the broadcast, so after hoisting the
    // concat its axis is expected to drop from 2 to 1.
    const std::vector<std::size_t> in_lens     = {3, 4};
    const std::vector<std::size_t> mbcast_lens = {2, 3, 4};
    const int axis                             = 2;

    auto m                   = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
    const auto expected_lens = m.get_output_shapes().back().lens();
    const auto size_before   = std::distance(m.begin(), m.end());
    run_pass(m);

    // Output shape is preserved and three multibroadcasts collapse into one.
    EXPECT(m.get_output_shapes().back().lens() == expected_lens);
    EXPECT(std::distance(m.begin(), m.end()) == size_before - 2);

    // Locates the first instruction with the given operator name.
    const auto find_op = [&](const char* name) {
        return std::find_if(
            m.begin(), m.end(), [=](const auto& ins) { return ins.name() == name; });
    };

    auto new_concat = find_op("concat");
    EXPECT(bool{new_concat != m.end()});
    auto new_mb = find_op("multibroadcast");

    // The multibroadcast must directly follow the concat.
    const auto cd = std::distance(m.begin(), new_concat);
    const auto md = std::distance(m.begin(), new_mb);
    EXPECT(cd == md - 1);
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
382+
383+
TEST_CASE(concat_multibroadcasts2)
{
    // Only the middle dim is broadcast and the rank is unchanged, so the
    // concat axis (0) is expected to stay the same after hoisting.
    const std::vector<std::size_t> in_lens     = {3, 1, 4};
    const std::vector<std::size_t> mbcast_lens = {3, 2, 4};
    const int axis                             = 0;

    auto m                   = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
    const auto expected_lens = m.get_output_shapes().back().lens();
    const auto size_before   = std::distance(m.begin(), m.end());
    run_pass(m);

    // Output shape is preserved and three multibroadcasts collapse into one.
    EXPECT(m.get_output_shapes().back().lens() == expected_lens);
    EXPECT(std::distance(m.begin(), m.end()) == size_before - 2);

    // Locates the first instruction with the given operator name.
    const auto find_op = [&](const char* name) {
        return std::find_if(
            m.begin(), m.end(), [=](const auto& ins) { return ins.name() == name; });
    };

    auto new_concat = find_op("concat");
    EXPECT(bool{new_concat != m.end()});
    auto new_mb = find_op("multibroadcast");

    // The multibroadcast must directly follow the concat.
    const auto cd = std::distance(m.begin(), new_concat);
    const auto md = std::distance(m.begin(), new_mb);
    EXPECT(cd == md - 1);
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 0);
}
405+
406+
TEST_CASE(concat_multibroadcasts3)
{
    // Only the middle dim is broadcast and the rank is unchanged, so the
    // concat axis (2, after the broadcast dim) is expected to stay the same.
    const std::vector<std::size_t> in_lens     = {3, 1, 4};
    const std::vector<std::size_t> mbcast_lens = {3, 2, 4};
    const int axis                             = 2;

    auto m                   = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
    const auto expected_lens = m.get_output_shapes().back().lens();
    const auto size_before   = std::distance(m.begin(), m.end());
    run_pass(m);

    // Output shape is preserved and three multibroadcasts collapse into one.
    EXPECT(m.get_output_shapes().back().lens() == expected_lens);
    EXPECT(std::distance(m.begin(), m.end()) == size_before - 2);

    // Locates the first instruction with the given operator name.
    const auto find_op = [&](const char* name) {
        return std::find_if(
            m.begin(), m.end(), [=](const auto& ins) { return ins.name() == name; });
    };

    auto new_concat = find_op("concat");
    EXPECT(bool{new_concat != m.end()});
    auto new_mb = find_op("multibroadcast");

    // The multibroadcast must directly follow the concat.
    const auto cd = std::distance(m.begin(), new_concat);
    const auto md = std::distance(m.begin(), new_mb);
    EXPECT(cd == md - 1);
    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 2);
}
428+
429+
TEST_CASE(concat_multibroadcasts4)
{
    // The concat axis is the broadcasted batch dim itself, so the pass is
    // expected to leave the module exactly as it was.
    const std::vector<std::size_t> in_lens     = {3, 4};
    const std::vector<std::size_t> mbcast_lens = {2, 3, 4};
    const int axis                             = 0;

    auto m         = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
    auto reference = m;
    run_pass(m);
    EXPECT(reference == m);
}
440+
340441
TEST_CASE(concat_transpose1)
341442
{
342443
migraphx::module m;

test/verify/test_unbatched_gemm_1.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* The MIT License (MIT)
3+
*
4+
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to deal
8+
* in the Software without restriction, including without limitation the rights
9+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
* copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in
14+
* all copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22+
* THE SOFTWARE.
23+
*/
24+
25+
#include "verify_program.hpp"
26+
#include <migraphx/program.hpp>
27+
#include <migraphx/generate.hpp>
28+
#include <migraphx/make_op.hpp>
29+
#include <migraphx/apply_alpha_beta.hpp>
30+
// Verify program: a batched A {4, 384, 768} is multiplied by a B built by
// concatenating three batch-broadcast {768, 768} literals along axis 2, with
// a {4, 384, 2304} C input, through add_apply_alpha_beta (alpha = beta = 1).
struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();

        const migraphx::shape a_shape{migraphx::shape::float_type, {4, 384, 768}};
        const migraphx::shape weight_shape{migraphx::shape::float_type, {768, 768}};
        const migraphx::shape c_shape{migraphx::shape::float_type, {4, 384, 2304}};

        auto a = mm->add_parameter("1", a_shape);

        // Adds one generated literal followed by its broadcast over the batch dim.
        const auto add_broadcast_weight = [&] {
            auto weight = mm->add_literal(migraphx::generate_literal(weight_shape));
            return mm->add_instruction(
                migraphx::make_op("multibroadcast", {{"out_lens", {4, 768, 768}}}), weight);
        };
        auto w0 = add_broadcast_weight();
        auto w1 = add_broadcast_weight();
        auto w2 = add_broadcast_weight();

        auto b = mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), w0, w1, w2);
        auto c = mm->add_parameter("3", c_shape);

        const float alpha = 1.0f;
        const float beta  = 1.0f;
        migraphx::add_apply_alpha_beta(*mm, {a, b, c}, migraphx::make_op("dot"), alpha, beta);
        return p;
    }
};

test/verify/test_unbatched_gemm_2.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* The MIT License (MIT)
3+
*
4+
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to deal
8+
* in the Software without restriction, including without limitation the rights
9+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
* copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in
14+
* all copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22+
* THE SOFTWARE.
23+
*/
24+
25+
#include "verify_program.hpp"
26+
#include <migraphx/program.hpp>
27+
#include <migraphx/generate.hpp>
28+
#include <migraphx/make_op.hpp>
29+
#include <migraphx/apply_alpha_beta.hpp>
30+
// Verify program: a batched A {4, 384, 768} is multiplied (dot) by a single
// {768, 768} literal that is multibroadcast over the batch dimension.
struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();

        const migraphx::shape a_shape{migraphx::shape::float_type, {4, 384, 768}};
        const migraphx::shape weight_shape{migraphx::shape::float_type, {768, 768}};

        auto a      = mm->add_parameter("1", a_shape);
        auto weight = mm->add_literal(migraphx::generate_literal(weight_shape));
        // Broadcast the 2-D weight across the batch dimension of A.
        auto b = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {4, 768, 768}}}), weight);

        mm->add_instruction(migraphx::make_op("dot"), a, b);
        return p;
    }
};

0 commit comments

Comments
 (0)