Add support for normal distribution RNG #2171
@@ -46,7 +46,7 @@ __device__ uint4 philox(

 __device__ float uniformf(unsigned int x) {
   constexpr float kRanInvM32 = 2.3283064e-10f; // Inverse of 2^32.
-  float result = x * kRanInvM32;
+  float result = x * kRanInvM32 + kRanInvM32 / 2.0f;
This keeps the implementation consistent with cuRAND. The extra term is present in the cuRAND version shipped with CUDA 11.6.
The Box-Muller normal RNG first draws uniform samples from this function. Without the extra term, the normal RNG results would differ substantially from cuRAND's.
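To illustrate why the half-step offset matters for the normal RNG, here is a minimal host-side C++ sketch. `uniformf` follows the patched line from the diff with `__device__` dropped so it can run on the CPU. `box_muller` is a hypothetical helper using the textbook Box-Muller formulas; it is an assumption for illustration, not code taken from this PR or from cuRAND, and the real implementation may differ in detail.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <utility>

// Host-side sketch of the patched uniformf from the diff (no __device__).
// With the extra kRanInvM32 / 2 term, x == 0 maps to 2^-33 instead of 0.
inline float uniformf(std::uint32_t x) {
  constexpr float kRanInvM32 = 2.3283064e-10f; // Inverse of 2^32.
  return x * kRanInvM32 + kRanInvM32 / 2.0f;
}

// Hypothetical helper: textbook Box-Muller transform turning two uniform
// samples into two standard-normal samples. Assumed form, not the PR's code.
inline std::pair<float, float> box_muller(float u1, float u2) {
  // log(0) is -infinity, so u1 must be strictly positive.
  assert(u1 > 0.0f);
  constexpr float kTwoPi = 6.2831853f;
  const float radius = std::sqrt(-2.0f * std::log(u1));
  const float angle = kTwoPi * u2;
  return {radius * std::cos(angle), radius * std::sin(angle)};
}
```

For example, `box_muller(uniformf(0u), uniformf(1u))` stays finite because `uniformf(0u)` is `kRanInvM32 / 2` rather than 0. Without the offset, the same call would take the logarithm of zero.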
Do the uniform tests pass on both pre-11.6 and post-11.6 CUDA versions?
The tests passed locally in my CUDA 11.6 environment. I'll start a build and test in a CUDA 11.8 environment as well.
I've checked that the new tests for Normal and Uniform also pass with CUDA 11.8 on an A100:
build/bin/test_jit --gtest_filter='*FusionNormal*:*FusionUniform*'
@@ -157,6 +157,11 @@ TORCH_CUDA_CU_API TensorView* uniform(
     Val* low,
     Val* high,
     DataType dtype);
+TORCH_CUDA_CU_API TensorView* normal(
Not necessarily in this PR, but we should also add randn and randn_like.
Mostly good now. Left some final comments.
Co-authored-by: Gao, Xiang <[email protected]>
@xwang233 Do you have permission to merge PRs?
I don't see a merge button on this PR.
I will merge this PR for you. Pinging @csarofeen for permission.
See also #1986