Fast softmax #1290
Conversation
Codecov Report
@@             Coverage Diff             @@
##           develop    #1290      +/-   ##
===========================================
- Coverage    93.15%   92.86%    -0.29%
===========================================
  Files          459      446       -13
  Lines        15604    14908      -696
===========================================
- Hits         14536    13845      -691
+ Misses        1068     1063        -5
Sort out why we're getting NaNs from test cases that are failing.
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
auto batch_sum =
    r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output, input);
Looks like this change is breaking some expected output in our tests. After a brief talk with @kahmed10, it looks like we were accounting for overflow here and are now getting NaN in some of the outputs. The change itself looks sane, but I think we need to dig into this test case and make sure what you've done doesn't break correctness.
Which test is failing?
Odd, looks like that resolved when develop was merged in. It was failing on clang_release in the previous commit.
It was saying
test_py_3.8_backend .......................................................***Failed 550.96 sec
======================================================================
[2022-07-07T03:08:09.031Z] FAIL: test_softmax_large_number_cpu (__main__.OnnxBackendNodeModelTest)
[2022-07-07T03:08:09.031Z] ----------------------------------------------------------------------
...
[2022-07-07T03:08:09.032Z] x and y nan location mismatch:
[2022-07-07T03:08:09.032Z] x: array([[0.032059, 0.087144, 0.236883, 0.643914],
[2022-07-07T03:08:09.032Z] [0.032059, 0.087144, 0.236883, 0.643914]], dtype=float32)
[2022-07-07T03:08:09.032Z] y: array([[0.032059, 0.087144, 0.236883, 0.643914],
[2022-07-07T03:08:09.032Z] [ nan, nan, nan, nan]], dtype=float32)
[2022-07-07T03:08:09.032Z] ======================================================================
[2022-07-07T03:08:09.032Z] FAIL: test_shufflenet_cpu (__main__.OnnxBackendRealModelTest)
[2022-07-07T03:08:09.032Z] ----------------------------------------------------------------------
...
[2022-07-07T03:08:09.032Z] raise AssertionError(msg)
[2022-07-07T03:08:09.032Z] AssertionError:
[2022-07-07T03:08:09.032Z] Not equal to tolerance rtol=0.001, atol=1e-05
I see, so the test is using very large inputs. On bert (and most models), the inputs shouldn't be that large. I am not sure how we can configure this.
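For context, a minimal standalone sketch (plain C++, not the MIGraphX kernel API) of why such large inputs turn into NaN once the max subtraction is dropped: float exp() overflows to inf for inputs above roughly 88, and inf / inf is NaN.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iterator>

int main()
{
    // Inputs on the order of the failing test's values (around 10000).
    float x[4] = {10000.0f, 10001.0f, 10002.0f, 10003.0f};

    // Naive softmax: std::exp(10000.0f) overflows to inf, the sum becomes inf,
    // and every output is inf / inf == NaN.
    float naive_sum = 0.0f;
    for(float v : x)
        naive_sum += std::exp(v);
    std::printf("naive:          %f\n", std::exp(x[0]) / naive_sum);

    // Max-subtracted softmax: every exponent is <= 0, so nothing overflows.
    float max = *std::max_element(std::begin(x), std::end(x));
    float stable_sum = 0.0f;
    for(float v : x)
        stable_sum += std::exp(v - max);
    std::printf("max-subtracted: %f\n", std::exp(x[0] - max) / stable_sum);
}
```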
Since I've gotten back online, this has been bothering me, so I went over softmax a bit.
I'm curious, why did we want to change this to remove the subtraction? The lookup to find batch_max should be O(n), so we're worst case 2·O(n) at runtime, unless the concern here is that we're doing this O(n) for every item in the batch. Then again, I can see the motivation here for the 2x performance boost. Were we getting NaN results from large values before?
Thinking about softmax more, couldn't we do something like rearrange the top exponential and distribute it into the lower sum, instead of doing the linear lookup? Multiplying each output by exp(-xi) / exp(-xi), you'd guarantee that all your values in the sum must be less than exp(xi - xj), and we're less likely to roll over.
auto batch_sum = r.reduce( op::sum{}, 0, [](auto x) { return migraphx::convert<float>(migraphx::exp(xi - xj)); })(input);
r.inner([&](auto& y, auto x) { y = 1 / batch_sum; })(output, input);
My syntax for this is probably off, but the idea here would be to save one migraphx::exp() call and fold the extra subtraction into the sum using the current index, instead of doing the max lookup.
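If I'm reading the idea right, it amounts to computing, for each element, y_i = 1 / sum_j exp(x_j - x_i). A rough scalar sketch of that (illustration only, not the actual reduce/inner kernel API):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar sketch of the "distribute exp(-xi) into the sum" idea:
//   y_i = 1 / sum_j exp(x_j - x_i)
// which is algebraically equal to softmax, but every exponent is a difference
// of two inputs rather than a raw input. Note the inner loop: it needs a full
// sum per output element, so a single reduction can't be reused as-is.
std::vector<float> softmax_rearranged(const std::vector<float>& x)
{
    std::vector<float> y(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        float sum = 0.0f;
        for(std::size_t j = 0; j < x.size(); ++j)
            sum += std::exp(x[j] - x[i]);
        y[i] = 1.0f / sum;
    }
    return y;
}
```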
I'm curious, why did we want to change this to remove the subtraction?
We don't have to do an extra reduction.
Were we getting NaN results from large values before?
I don't think we were, except for this test case.
Multiplying each output by exp(-xi) / exp(-xi), you'd guarantee that all your values in the sum must be less than exp(xi - xj), and we're less likely to roll over.
Hmm, that's an interesting idea.
Multiplying each output by exp(-xi) / exp(-xi), you'd guarantee that all your values in the sum must be less than exp(xi - xj), and we're less likely to roll over.
Actually this won't work, as we would need a sum for each element instead of reusing the sum. Another idea I have is that maybe we can just take the first element instead of the max.
Multiplying each output by exp(-xi) / exp(-xi), you'd guarantee that all your values in the sum must be less than exp(xi - xj), and we're less likely to roll over.
Actually this won't work, as we would need a sum for each element instead of reusing the sum. Another idea I have is that maybe we can just take the first element instead of the max.
Is there a way for us to just use:
normalized_sum = batch_sum * exp(-xi);
for each element? That would save us from recalculating a sum every run and still avoid the O(n) penalty.
It gives us the added bonus of discerning whether we're blowing up in the numerator or the denominator of the expression.
If the denominator is blowing up, we could probably do some sort of transformation on the sum part. There's probably an identity for the sum of exponentials we can leverage to simplify and put the powers back in at the end.
If the numerator has blown up, we don't worry much, since then we only have to worry about whether the value of 1 / normalized_sum is close to -FLOAT_MAX.
My intuition here is that this operator has to output a value between 0 and 1, since we get a probability at the end, so we know we'd have to normalize things within that range and just manage the intermediate representation of the values regardless of how massive the input is.
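A rough scalar sketch of that reuse idea (illustration only, with the rescaling written so that 1 / normalized_sum reproduces softmax); the catch is that batch_sum = sum(exp(x_j)) can itself overflow for the large inputs this thread is about, so the per-element rescaling alone wouldn't remove the NaNs:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Sketch: compute batch_sum once, then rescale it per element so that
//   y_i = 1 / (batch_sum * exp(-x_i)) == exp(x_i) / batch_sum.
// Caveat: batch_sum = sum(exp(x_j)) overflows to inf for large inputs,
// so this reuse does not fix the overflow/NaN case discussed above.
std::vector<float> softmax_rescaled_sum(const std::vector<float>& x)
{
    float batch_sum = 0.0f;
    for(float v : x)
        batch_sum += std::exp(v);

    std::vector<float> y(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        float normalized_sum = batch_sum * std::exp(-x[i]);
        y[i] = 1.0f / normalized_sum;
    }
    return y;
}
```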
We can't just do
What's large though? Doing a quick check with FLOAT_MAX/MIN, the natural logs of those values come out to around 88 and -87 respectively; anything above those in exp(x) surpasses what a float can hold. Unless we do some sort of tap dance with the natural logarithm and do some algebra on the final output. What's difficult here is that softmax itself is not scale invariant, so we can't just scale the input before taking the exponential either, as the sum is a tad cumbersome. It is in fact translation invariant, though, but that doesn't help us if the input values are outside the larger range; otherwise we could just convert the input to some better representation and convert it back.
In the test case, it was 10000, and in the AgentModel I was seeing even larger values.
This build is not recommended to merge 🔴
Wowza on the 7.8% jump for bert onnx, that's huge!
That may be an issue. Curious why it's such a large drop with mobilenet_1 too.
This build is not recommended to merge 🔴
auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) {
    return migraphx::convert<float>(migraphx::exp(x - c));
})(input);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input);
I think we should still be using the max in the calculation, but I've seen an alternative that still saves us from blowing up our exp(), prevents overflow, and solves the correctness issue with NaNs: calculate log(output) instead and then convert it back.
This gives us the following:
// since
batch_sum = sum(exp(input))
max = max(input)
batch_sum_sub_max = sum(exp(input - max))
output = exp(x) / batch_sum = exp(x - max) / batch_sum_sub_max
// take the natural log of the output
log(output) = (x - max) - log(batch_sum_sub_max)
// thus
output = exp(x - max - log(batch_sum_sub_max))
Refs:
https://stackoverflow.com/questions/9906136/implementation-of-a-softmax-activation-function-for-neural-networks
https://lingpipe-blog.com/2009/06/25/log-sum-of-exponentials/
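In standalone scalar form, the formulation above would look roughly like this (a sketch for illustration, not the actual kernel code):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Log-sum-exp softmax sketch:
//   y_i = exp(x_i - max - log(sum_j exp(x_j - max)))
// Every argument passed to exp() is <= 0 (the shifted sum is >= 1, so its log
// is >= 0), so nothing overflows, at the cost of one extra log() on top of
// the max reduction.
std::vector<float> softmax_logsumexp(const std::vector<float>& x)
{
    float max = *std::max_element(x.begin(), x.end());

    float batch_sum_sub_max = 0.0f;
    for(float v : x)
        batch_sum_sub_max += std::exp(v - max);

    float log_sum = std::log(batch_sum_sub_max);

    std::vector<float> y(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
        y[i] = std::exp(x[i] - max - log_sum);
    return y;
}
```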
The issue now is whether we'd still have a speedup, since this keeps the max, unless we have a fast way of doing the log() operator, or even exp. The nice thing here is that we can split the final calculation once we have the input max and sum, though.
I'm not sure how to do that right now in our migraphx context and for this kernel, but I assume that's taken care of between the multiple inner statements?
It looks like, with the flag, we'll enable this on a case-by-case basis. I suppose this will give us a speedup on models where we know the input to the softmax won't exceed the (-83, 82) range we saw blow up the calculation.
Maybe there will need to be logic on top of this later down the road to decide when to toggle this speedup.