Conversation


@pfultz2 pfultz2 commented Jul 7, 2022

No description provided.


codecov bot commented Jul 7, 2022

Codecov Report

Merging #1290 (f5df3d6) into develop (c9ffb38) will decrease coverage by 0.28%.
The diff coverage is n/a.

❗ Current head f5df3d6 differs from pull request most recent head e34c152. Consider uploading reports for the commit e34c152 to get more accurate results

@@             Coverage Diff             @@
##           develop    #1290      +/-   ##
===========================================
- Coverage    93.15%   92.86%   -0.29%     
===========================================
  Files          459      446      -13     
  Lines        15604    14908     -696     
===========================================
- Hits         14536    13845     -691     
+ Misses        1068     1063       -5     
Impacted Files Coverage Δ
src/include/migraphx/op/pow.hpp 0.00% <0.00%> (-100.00%) ⬇️
src/include/migraphx/assert.hpp 28.57% <0.00%> (-8.93%) ⬇️
src/include/migraphx/match/gelu_tanh.hpp 91.66% <0.00%> (-8.34%) ⬇️
src/targets/cpu/copy.cpp 9.09% <0.00%> (-7.58%) ⬇️
src/generate.cpp 94.73% <0.00%> (-5.27%) ⬇️
src/include/migraphx/op/multinomial.hpp 95.83% <0.00%> (-4.17%) ⬇️
src/dom_info.cpp 96.15% <0.00%> (-3.85%) ⬇️
src/include/migraphx/errors.hpp 71.42% <0.00%> (-3.58%) ⬇️
src/include/migraphx/argument.hpp 78.57% <0.00%> (-2.68%) ⬇️
src/tf/parse_binary_op.cpp 83.33% <0.00%> (-2.39%) ⬇️
... and 295 more


@TedThemistokleous TedThemistokleous self-assigned this Jul 29, 2022

@TedThemistokleous TedThemistokleous left a comment


Sort out why we're getting NaNs from test cases that are failing.

auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
auto batch_sum =
    r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output, input);
Collaborator

Looks like this change is breaking some expected output in our tests. After a brief talk with @kahmed10, it looks like we were previously accounting for overflow here and are now getting NaN in some of the outputs. The change itself looks sane, but I think we need to delve into this test case and make sure what you've done doesn't break correctness.
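
For reference, a minimal standalone sketch of the overflow that the max subtraction guards against (plain C++ with made-up logit values, not the MIGraphX kernel itself):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iterator>

int main()
{
    // Illustrative row of large logits; the values are invented for this sketch.
    float x[4] = {10000.0f, 10001.0f, 10002.0f, 10003.0f};
    float m    = *std::max_element(std::begin(x), std::end(x));

    float naive_sum = 0.0f, stable_sum = 0.0f;
    for(float v : x)
    {
        naive_sum += std::exp(v);      // exp(10000) overflows a float to inf
        stable_sum += std::exp(v - m); // every argument is <= 0, so this stays finite
    }

    std::printf("naive : %f\n", std::exp(x[0]) / naive_sum);      // inf / inf -> nan
    std::printf("stable: %f\n", std::exp(x[0] - m) / stable_sum); // finite probability
    return 0;
}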

Collaborator Author

Which test is failing?

Collaborator

Odd, looks like that resolved itself when develop was merged in. It was failing on clang_release in the previous commit.

Collaborator

It was saying

test_py_3.8_backend .......................................................***Failed  550.96 sec
======================================================================

[2022-07-07T03:08:09.031Z] FAIL: test_softmax_large_number_cpu (__main__.OnnxBackendNodeModelTest)

[2022-07-07T03:08:09.031Z] ----------------------------------------------------------------------
...

[2022-07-07T03:08:09.032Z] x and y nan location mismatch:

[2022-07-07T03:08:09.032Z]  x: array([[0.032059, 0.087144, 0.236883, 0.643914],

[2022-07-07T03:08:09.032Z]        [0.032059, 0.087144, 0.236883, 0.643914]], dtype=float32)

[2022-07-07T03:08:09.032Z]  y: array([[0.032059, 0.087144, 0.236883, 0.643914],

[2022-07-07T03:08:09.032Z]        [     nan,      nan,      nan,      nan]], dtype=float32)
[2022-07-07T03:08:09.032Z] ======================================================================

[2022-07-07T03:08:09.032Z] FAIL: test_shufflenet_cpu (__main__.OnnxBackendRealModelTest)

[2022-07-07T03:08:09.032Z] ----------------------------------------------------------------------
...
[2022-07-07T03:08:09.032Z]     raise AssertionError(msg)

[2022-07-07T03:08:09.032Z] AssertionError: 

[2022-07-07T03:08:09.032Z] Not equal to tolerance rtol=0.001, atol=1e-05

Collaborator Author

I see, so the test is using very large inputs. On bert (and most models), the inputs shouldn't be that large. I am not sure how we can configure this.

Collaborator

Since I've gotten back online, this has been bothering me, so I went over softmax a bit.

I'm curious, why did we want to change this to remove the subtraction? The lookup to find the batch_max should be O(n), so we're worst case 2x O(n) at runtime, unless the concern is that we're doing this O(n) pass for every item in the batch. Then again, I can see the motivation here for the 2x performance boost. Were we getting large NaN results before?

Thinking about softmax more, couldn't we do something like rearrange the top exponential and distribute it into the lower sum, instead of doing the linear lookup? Multiplying each output by exp(-xi)/exp(-xi), you'd guarantee that all your values in the sum must be less than exp(xi - xj), and we're less likely to roll over.

auto batch_sum = r.reduce( op::sum{}, 0, [](auto x) { return migraphx::convert<float>(migraphx::exp(xi - xj)); })(input);
r.inner([&](auto& y, auto x) { y = 1 / batch_sum; })(output, input);

My syntax for this is probably off, but the idea here would be to save one migraphx::exp() call and then fold the extra subtraction into the sum using the current index instead of doing the lookup.
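
Spelled out on a plain std::vector (a sketch only, not the MIGraphX reducer API; note the inner loop, which means the denominator has to be recomputed for every output element):

#include <cmath>
#include <cstddef>
#include <vector>

// softmax_i = exp(x_i) / sum_j exp(x_j) = 1 / sum_j exp(x_j - x_i)
std::vector<float> softmax_rearranged(const std::vector<float>& x)
{
    std::vector<float> y(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        float denom = 0.0f;
        for(float xj : x)                 // a full sum per output element
            denom += std::exp(xj - x[i]);
        y[i] = 1.0f / denom;
    }
    return y;
}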

Collaborator Author

I'm curious, why did we want to change this to remove the subtraction?

We don't have to do an extra reduction.

Were we getting large NaN results before?

I don't think we were, except for this test case.

Collaborator Author

Multiplying each output by exp(-xi)/exp(-xi), you'd guarantee that all your values in the sum must be less than exp(xi - xj), and we're less likely to roll over.

Hmm, that's an interesting idea.

Collaborator Author

Multiplying each output by exp(-xi)/exp(-xi), you'd guarantee that all your values in the sum must be less than exp(xi - xj), and we're less likely to roll over.

Actually this won't work, as we would need a sum for each element instead of reusing the sum. Another idea I have is that maybe we can just take the first element instead of the max.

Collaborator

@TedThemistokleous TedThemistokleous Aug 16, 2022

Multiplying each output by exp(-xi)/exp(-xi), you'd guarantee that all your values in the sum must be less than exp(xi - xj), and we're less likely to roll over.

Actually this won't work, as we would need a sum for each element instead of reusing the sum. Another idea I have is that maybe we can just take the first element instead of the max.

Is there a way for us to just use:

normalized_sum = batch_sum / exp(-xi);

For each element? That would save us from recalculating a sum every run and still avoid the O(n) penalty.

It gives us the added bonus of discerning whether we're blowing up in the numerator or the denominator of the expression.

If the denominator's blowing up, we could probably do some sort of transformation on the sum part. There's probably an identity for the sum of exponentials we can leverage to simplify it and bring the powers back in at the end.

If the numerator's blown up, we don't worry as much, since now we just have to check whether the value of 1/normalized_sum is close to -FLOAT_MAX.

My intuition here is that this operator has to output a value between 0 and 1, since we get a probability at the end, so we know that we'd have to normalize things within that range and just mitigate the intermediate representation of the values regardless of whether the input is massive.


pfultz2 commented Aug 16, 2022

Is there a way for us to just use:
normalized_sum = batch_sum / exp(-xi);

We can't just do exp(x) or exp(-x), as this will cause NaNs for large values of x.


TedThemistokleous commented Aug 16, 2022

Is there a way for us to just use:
normalized_sum = batch_sum / exp(-xi);

We can't just do exp(x) or exp(-x), as this will cause NaNs for large values of x.

What's large though? Doing a quick check with FLOAT_MAX/MIN, the natural log of those values rounds to around 88 and -87 respectively; anything above those in exp(x) surpasses what a float can hold. Unless we do some sort of tap dance with the natural logarithm and do some algebra on the final output.

What's difficult here is that softmax itself is not scale invariant, so we can't just scale the input before taking the exponential either; the sum makes that a tad cumbersome.

It is in fact translation invariant, though, but that doesn't help us if the input values are outside the larger range; otherwise we could just convert the input to some better representation and convert it back.
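
A quick standalone check of the 88 / -87 limits mentioned above (plain C++, just to illustrate the thresholds):

#include <cmath>
#include <cstdio>
#include <limits>

int main()
{
    // Largest exponent that still fits in a float: ln(FLT_MAX) ~= 88.72
    std::printf("ln(FLT_MAX) = %f\n", std::log(std::numeric_limits<float>::max()));
    std::printf("ln(FLT_MIN) = %f\n", std::log(std::numeric_limits<float>::min()));
    std::printf("exp(88.0f)  = %e\n", std::exp(88.0f)); // finite, ~1.65e38
    std::printf("exp(89.0f)  = %e\n", std::exp(89.0f)); // overflows to inf
    return 0;
}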


pfultz2 commented Aug 16, 2022

the natural log of those values rounds to around 88 and -87 respectively; anything above those in exp(x) surpasses what a float can hold

In the test case, it was 10000, and in the AgentModel I was seeing even larger values.

@migraphx-bot
Collaborator

Test    Rate new (70687f)    Rate old (cb5368)    Diff
torchvision-resnet50 2,224.40 2,224.84 -0.02%
torchvision-resnet50_fp16 4,749.29 4,745.80 0.07%
torchvision-alexnet 4,983.80 4,976.36 0.15%
torchvision-alexnet_fp16 26,341.78 26,143.04 0.76% 🔆
torchvision-densenet121 1,631.48 1,630.56 0.06%
torchvision-densenet121_fp16 2,526.37 2,526.69 -0.01%
torchvision-inceptionv3 1,096.64 1,096.31 0.03%
torchvision-inceptionv3_fp16 1,986.90 1,984.36 0.13%
torchvision-vgg16 896.32 895.76 0.06%
torchvision-vgg16_fp16 1,727.37 1,726.23 0.07%
cadene-inceptionv4 528.56 528.10 0.09%
cadene-resnext64x4 578.05 577.05 0.17%
slim-mobilenet 6,400.63 6,396.87 0.06%
slim-nasnetalarge 203.50 203.32 0.09%
slim-resnet50v2 2,431.49 2,429.83 0.07%
bert-mrpc-onnx 689.27 639.40 7.80% 🔆
bert-mrpc-tf 300.30 296.01 1.45% 🔆
pytorch-examples-wlang-gru 229.44 229.72 -0.12%
pytorch-examples-wlang-lstm 306.54 306.02 0.17%
torchvision-resnet50_1 512.79 513.42 -0.12%
torchvision-inceptionv3_1 302.82 303.19 -0.12%
torchvision-vgg16_1 463.49 463.62 -0.03%
cadene-dpn92_1 309.46 307.77 0.55%
cadene-resnext101_1 229.97 238.31 -3.50%
slim-vgg16_1 64.04 64.01 0.05%
slim-mobilenet_1 1,964.29 1,992.50 -1.42% 🔴
slim-inceptionv4_1 195.54 195.90 -0.18%
onnx-taau-downsample 258.95 258.89 0.02%

This build is not recommended to merge 🔴

@TedThemistokleous
Collaborator

Wowza on the 7.8% jump for bert onnx, that's huge!

the natural log of those values rounds to around 88 and -87 respectively; anything above those in exp(x) surpasses what a float can hold

In the test case, it was 10000, and in the AgentModel I was seeing even larger values.

That may be an issue. Curious why it's such a large drop with mobilenet_1 too.


migraphx-bot commented Aug 17, 2022

Test    Rate new (e34c15)    Rate old (b36f7a)    Diff
torchvision-resnet50 2,244.14 2,243.36 0.03%
torchvision-resnet50_fp16 4,862.96 4,868.96 -0.12% 🔴
torchvision-alexnet 4,970.60 4,970.22 0.01%
torchvision-alexnet_fp16 25,517.53 25,556.66 -0.15%
torchvision-densenet121 1,814.76 1,815.17 -0.02%
torchvision-densenet121_fp16 3,275.00 3,269.64 0.16% 🔆
torchvision-inceptionv3 1,106.47 1,096.95 0.87% 🔆
torchvision-inceptionv3_fp16 1,959.04 1,977.99 -0.96% 🔴
torchvision-vgg16 894.21 893.98 0.03%
torchvision-vgg16_fp16 1,740.81 1,739.94 0.05% 🔆
cadene-inceptionv4 533.90 534.06 -0.03%
cadene-resnext64x4 578.64 578.14 0.09% 🔆
slim-mobilenet 6,541.84 6,543.53 -0.03%
slim-nasnetalarge 208.15 208.18 -0.02%
slim-resnet50v2 nan nan nan%
bert-mrpc-onnx 808.38 808.52 -0.02%
bert-mrpc-tf 314.51 313.00 0.48% 🔆
pytorch-examples-wlang-gru 426.44 437.40 -2.51% 🔴
pytorch-examples-wlang-lstm 384.18 385.42 -0.32% 🔴
torchvision-resnet50_1 515.66 516.08 -0.08%
torchvision-inceptionv3_1 304.54 305.28 -0.24% 🔴
torchvision-vgg16_1 462.86 463.21 -0.08% 🔴
cadene-dpn92_1 329.71 321.59 2.52% 🔆
cadene-resnext101_1 233.88 233.78 0.04%
slim-vgg16_1 63.99 64.01 -0.03% 🔴
slim-mobilenet_1 1,964.74 1,975.96 -0.57%
slim-inceptionv4_1 193.20 193.88 -0.35% 🔴
onnx-taau-downsample 255.61 255.94 -0.13% 🔴

This build is not recommended to merge 🔴

auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) {
    return migraphx::convert<float>(migraphx::exp(x - c));
})(input);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input);
Collaborator

@TedThemistokleous TedThemistokleous Aug 18, 2022

I think we should still be using the max in the calculation, but I've seen an alternative that still saves us from blowing up our exp(), prevents overflow, and solves the correctness issue with NaNs:

Calculate log(output) instead and then convert it back.

This gives us the following:

// since
batch_sum = sum(exp(input))
max       = max(input)

batch_sum_sub_max = sum(exp(input - max))
output = exp(x) / batch_sum = exp(x - max) / batch_sum_sub_max
// take the natural log of the output
ln(output) = (x - max) - log(batch_sum_sub_max)
// thus
output = exp(x - max - log(batch_sum_sub_max))

Refs:
https://stackoverflow.com/questions/9906136/implementation-of-a-softmax-activation-function-for-neural-networks
https://lingpipe-blog.com/2009/06/25/log-sum-of-exponentials/

The issue now is whether we'd still have a speedup without removing the max, unless we have a fast way of doing the log() operator, or even exp(). The nice thing here is that we can split up the final calculation once we have the input max and the sum, though.

Not sure how to do that right now in our migraphx context and for this kernel, but I assume that's taken care of between the multiple inner statements?
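
A minimal standalone sketch of that log-sum-exp form (plain C++ over a std::vector, not the MIGraphX reducer API; the function name is made up for illustration):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// output[i] = exp(x[i] - max - log(sum_j exp(x[j] - max)))
std::vector<float> softmax_log_sum_exp(const std::vector<float>& x)
{
    float m = *std::max_element(x.begin(), x.end());
    float sum = 0.0f;
    for(float v : x)
        sum += std::exp(v - m); // every term is <= 1, so the sum cannot overflow
    float log_sum = std::log(sum);
    std::vector<float> y(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
        y[i] = std::exp(x[i] - m - log_sum);
    return y;
}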

Collaborator

@TedThemistokleous TedThemistokleous left a comment

It looks like, with the flag, we'll enable this on a case-by-case basis. I suppose this will give us a speedup on models where we know the input to the softmax won't exceed the (-83, 82) range we saw blow up the calculation.

Maybe there will need to be logic on top of this later down the road to decide when to toggle this speedup.

@causten causten requested a review from umangyadav September 21, 2022 15:25
@causten causten merged commit a9a4740 into develop Oct 4, 2022
@causten causten deleted the fastsoftmax branch October 4, 2022 13:05