Add support for normal distribution RNG #2171

Merged 13 commits on Nov 18, 2022

Conversation

@xwang233 xwang233 commented Nov 9, 2022

See also #1986

normal(shape, mean, std, dtype)

xwang233 commented Nov 9, 2022

@zasdfgbnm

@zasdfgbnm zasdfgbnm self-requested a review November 9, 2022 17:18
@@ -46,7 +46,7 @@ __device__ uint4 philox(

 __device__ float uniformf(unsigned int x) {
   constexpr float kRanInvM32 = 2.3283064e-10f; // Inverse of 2^32.
-  float result = x * kRanInvM32;
+  float result = x * kRanInvM32 + kRanInvM32 / 2.0f;
Author

This keeps the implementation consistent with cuRAND. The extra term is present in the cuRAND version that ships with CUDA 11.6.

The Box-Muller normal RNG first draws its uniform samples from this function. Without the extra term, the normal RNG output would show large numerical differences from the cuRAND results.
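
To make the effect of the extra term concrete, here is a minimal host-side C++ sketch. It is only an illustration under assumptions: `uniform_no_offset`, `uniform_with_offset`, and `box_muller` are hypothetical names, and `box_muller` is the textbook transform, not the kernel code from this PR or cuRAND's internal implementation. It shows that the old mapping can return exactly 0 for `x == 0`, which would feed `log(0)` into Box-Muller, while the half-step offset keeps the input strictly positive.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <utility>

// Inverse of 2^32, same constant as in the diff.
constexpr float kRanInvM32 = 2.3283064e-10f;

// Old mapping: x == 0 maps to exactly 0.0f.
inline float uniform_no_offset(std::uint32_t x) {
  return x * kRanInvM32;
}

// New mapping: the half-step offset keeps the result strictly positive.
inline float uniform_with_offset(std::uint32_t x) {
  return x * kRanInvM32 + kRanInvM32 / 2.0f;
}

// Textbook Box-Muller transform (hypothetical helper for illustration).
// Turns two uniform samples into two standard normal samples.
inline std::pair<float, float> box_muller(float u1, float u2) {
  assert(u1 > 0.0f && "log(0) would be -inf");
  constexpr float kTwoPi = 6.2831853f;
  const float radius = std::sqrt(-2.0f * std::log(u1));
  const float angle = kTwoPi * u2;
  return {radius * std::cos(angle), radius * std::sin(angle)};
}
```

Calling `box_muller(uniform_with_offset(a), uniform_with_offset(b))` for any two 32-bit inputs `a` and `b` yields finite values. Because `radius` depends on `log(u1)`, even a shift of half a step in the uniform sample changes the normal sample, which is why both implementations need the same mapping to produce matching results.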

Collaborator

Do the uniform tests pass on both pre-11.6 and post-11.6 CUDA versions?

Author

The tests passed locally in my CUDA 11.6 environment. I'll also start a build and test run in a CUDA 11.8 environment.

Author

I've checked that the new tests for Normal and Uniform also pass on an A100 with CUDA 11.8.

build/bin/test_jit --gtest_filter='*FusionNormal*:*FusionUniform*'

@@ -157,6 +157,11 @@ TORCH_CUDA_CU_API TensorView* uniform(
     Val* low,
     Val* high,
     DataType dtype);
+TORCH_CUDA_CU_API TensorView* normal(
Collaborator

Not necessarily in this PR, but we should also add randn and randn_like.

Collaborator

@zasdfgbnm zasdfgbnm left a comment

Mostly good now. Left some final comments.

@zasdfgbnm
Collaborator

@xwang233 Do you have permission to merge PRs?

@xwang233
Author

I don't see a merge button on this PR.

@zasdfgbnm
Collaborator

I will merge this PR for you. Pinging @csarofeen for permission.

@zasdfgbnm zasdfgbnm merged commit f7f8c3c into csarofeen:devel Nov 18, 2022
@xwang233 xwang233 mentioned this pull request Jan 24, 2023