-
Notifications
You must be signed in to change notification settings - Fork 104
Multibroadcast find_mul_conv #1384
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This build is not recommended to merge 🔴 |
src/simplify_algebra.cpp
Outdated
{ | ||
if(invalid_sl(i)) | ||
invalid_case = true; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be simpler to write this as:
auto is_broadcasted_axis = [](auto len, auto stride) {
return len == 1 or stride == 0;
};
if (not is_broadcasted_axis(a_lens.front(), a_strides.front()))
return;
if (not std::equal(a_lens.begin()+2, a_lens.end(), a_strides.begin()+2, a_strides.end(), is_broadcasted_axis))
return;
src/simplify_algebra.cpp
Outdated
|
||
// check broadcasted along channels | ||
auto a_lens = a_ins->get_shape().lens(); | ||
auto a_strides = a_ins->get_shape().strides(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These should be const auto&
.
src/simplify_algebra.cpp
Outdated
invalid_case = true; | ||
} | ||
|
||
if(invalid_sl(0) or a_strides.at(1) != 1 or invalid_case) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should check a_strides.at(1) != 1
before checking the broadcasted axis. Also, we should probably check the rank as well(ie a_lens.size() > 2
) before we start checking everything else.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't rank >= 4
to be a valid convolution output?
Change find_mul_conv to work with multibroadcast also. Checks the strides instead of the broadcast axis.