Skip to content
37 changes: 30 additions & 7 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ auto conv_const_weights()

auto reduction() { return match::name_contains("reduce"); }

// conv(x, w) * a => conv(x, a * w)
struct find_mul_conv
{
auto matcher() const
{
return match::name("mul")(match::either_arg(0, 1)(conv_const_weights().bind("conv"),
match::name("broadcast").bind("a")));
return match::name("mul")(
match::either_arg(0, 1)(conv_const_weights().bind("conv"),
match::name("broadcast", "multibroadcast").bind("a")));
}

void apply(module& m, const match::matcher_result& r) const
Expand All @@ -72,14 +74,35 @@ struct find_mul_conv
auto a_ins = r.instructions["a"];
auto w_ins = r.instructions["w"];

auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
if(broadcast_op.axis != 1)
const auto& a_input_lens = a_ins->inputs().front()->get_shape().lens();

std::size_t num_not_one_dims = std::count_if(
a_input_lens.cbegin(), a_input_lens.cend(), [](auto dim) { return dim != 1; });
if(num_not_one_dims > 1)
return;

// check broadcasted along channels
const auto& a_lens = a_ins->get_shape().lens();
const auto& a_strides = a_ins->get_shape().strides();

auto is_broadcasted_axis = [](auto len, auto stride) { return len == 1 or stride == 0; };

if(a_strides.at(1) != 1)
return;

if(not is_broadcasted_axis(a_lens.front(), a_strides.front()))
return;

if(not std::equal(a_lens.begin() + 2,
a_lens.end(),
a_strides.begin() + 2,
a_strides.end(),
is_broadcasted_axis))
return;

auto sq = m.insert_instruction(ins, make_op("squeeze"), a_ins->inputs().front());
auto new_a = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
a_ins->inputs().front());
ins, make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}), sq);
auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
Expand Down
99 changes: 99 additions & 0 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,105 @@ TEST_CASE(simplify_mul_conv1)
EXPECT(new_conv->outputs().front()->name() != "mul");
}

TEST_CASE(simplify_mul_conv2)
{
migraphx::module m;
auto x = m.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
auto w =
m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
auto conv = m.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto a = m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto unsq_a = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), a);
auto b = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 256, 14, 14}}}), unsq_a);
auto mul = m.add_instruction(migraphx::make_op("mul"), conv, b);
m.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul");
run_pass(m);
auto new_conv =
std::find_if(m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; });
EXPECT(new_conv->outputs().front()->name() != "mul");
}

// len = 1 case
TEST_CASE(simplify_mul_conv3)
{
migraphx::module m;
auto x = m.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
auto w =
m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
auto conv = m.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto a = m.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {256, 1, 1}, {1, 18, 1}}));
auto b =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 256, 14, 14}}}), a);
auto mul = m.add_instruction(migraphx::make_op("mul"), conv, b);
m.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul");
run_pass(m);
auto new_conv =
std::find_if(m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; });
EXPECT(new_conv->outputs().front()->name() != "mul");
}

// Previously broadcasted literal case, should skip
TEST_CASE(simplify_mul_conv_skip1)
{
migraphx::module m;
auto x = m.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
auto w =
m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
auto conv = m.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto a = m.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {256, 14, 14}, {1, 0, 0}}));
auto b = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 256, 14, 14}}}), a);
auto mul = m.add_instruction(migraphx::make_op("mul"), conv, b);
m.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul");
run_pass(m);
auto new_conv =
std::find_if(m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; });
EXPECT(new_conv->outputs().front()->name() == "mul");
}

// Another previously broadcasted literal case, should skip
TEST_CASE(simplify_mul_conv_skip2)
{
migraphx::module m;
auto x = m.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
auto w =
m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
auto conv = m.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto a = m.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {256, 14, 14}, {1, 0, 0}}));
auto b =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 256, 14, 14}}}), a);
auto mul = m.add_instruction(migraphx::make_op("mul"), conv, b);
m.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul");
run_pass(m);
auto new_conv =
std::find_if(m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; });
EXPECT(new_conv->outputs().front()->name() == "mul");
}

TEST_CASE(simplify_mul_slice_conv1)
{
migraphx::module m1;
Expand Down