Move the CUDA implementation of trunc to ATen. #25423
void trunc_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "trunc_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return std::trunc(a);
Are you sure there is no need for an overload that calls truncf?
Yeah, it's overloaded: https://en.cppreference.com/w/cpp/numeric/math/trunc (also search for VSTD::trunc in crt/math_functions.h in the cuda include dir)
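The overload claim can also be checked on the host with plain C++. The following is a minimal sketch, not code from this PR: `trunc_op` and `matches_truncf` are hypothetical names, and `trunc_op` only stands in for the kernel lambda, where `scalar_t` is deduced per dtype.

```cpp
#include <cassert>
#include <cmath>
#include <type_traits>

// Hypothetical stand-in for the kernel lambda: scalar_t is deduced from the
// argument, and std::trunc resolves to the overload for that type.
template <typename scalar_t>
scalar_t trunc_op(scalar_t a) {
  return std::trunc(a);
}

// The float overload returns float, so no promotion to double is involved.
static_assert(std::is_same<decltype(std::trunc(1.5f)), float>::value,
              "expected the float overload");
static_assert(std::is_same<decltype(std::trunc(1.5)), double>::value,
              "expected the double overload");

// Runtime sanity check: the float overload agrees with truncf.
inline bool matches_truncf(float x) {
  bool same = trunc_op(x) == std::truncf(x);
  assert(same);
  return same;
}
```

This only covers the host standard library; the device-side behaviour is what the PTX dump in this thread addresses.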
If you are worried, here's some evidence. Compile the following program using
nvcc -ptx -src-in-ptx -arch=sm_60 test.cu
#include <cuda_runtime.h>
__global__ void test_trunc_f(float x, float& x2) {
  x2 = std::trunc(x);
}
__global__ void test_trunc_d(double x, double& x2) {
  x2 = std::trunc(x);
}
__global__ void test_trunc_f_d(float x, float& x2) {
  x2 = truncf(x);
}
The output is shown below. Note that the first and third functions compile to the same PTX:
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-26907403
// Cuda compilation tools, release 10.1, V10.1.243
// Based on LLVM 3.4svn
//
.version 6.4
.target sm_60
.address_size 64
// .globl _Z12test_trunc_ffRf
.visible .entry _Z12test_trunc_ffRf(
.param .f32 _Z12test_trunc_ffRf_param_0,
.param .u64 _Z12test_trunc_ffRf_param_1
)
{
.reg .f32 %f<3>;
.reg .b64 %rd<3>;
ld.param.f32 %f1, [_Z12test_trunc_ffRf_param_0];
ld.param.u64 %rd1, [_Z12test_trunc_ffRf_param_1];
cvta.to.global.u64 %rd2, %rd1;
cvt.rzi.f32.f32 %f2, %f1;
st.global.f32 [%rd2], %f2;
ret;
}
// .globl _Z12test_trunc_ddRd
.visible .entry _Z12test_trunc_ddRd(
.param .f64 _Z12test_trunc_ddRd_param_0,
.param .u64 _Z12test_trunc_ddRd_param_1
)
{
.reg .f64 %fd<3>;
.reg .b64 %rd<3>;
ld.param.f64 %fd1, [_Z12test_trunc_ddRd_param_0];
ld.param.u64 %rd1, [_Z12test_trunc_ddRd_param_1];
cvta.to.global.u64 %rd2, %rd1;
cvt.rzi.f64.f64 %fd2, %fd1;
st.global.f64 [%rd2], %fd2;
ret;
}
// .globl _Z14test_trunc_f_dfRf
.visible .entry _Z14test_trunc_f_dfRf(
.param .f32 _Z14test_trunc_f_dfRf_param_0,
.param .u64 _Z14test_trunc_f_dfRf_param_1
)
{
.reg .f32 %f<3>;
.reg .b64 %rd<3>;
ld.param.f32 %f1, [_Z14test_trunc_f_dfRf_param_0];
ld.param.u64 %rd1, [_Z14test_trunc_f_dfRf_param_1];
cvta.to.global.u64 %rd2, %rd1;
cvt.rzi.f32.f32 %f2, %f1;
st.global.f32 [%rd2], %f2;
ret;
}
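In all three kernels the work is done by a single `cvt.rzi` instruction; in PTX the `.rzi` modifier rounds to an integer toward zero, which is the definition of truncation. As a rough host-side illustration of that rounding mode (a sketch with hypothetical helper names, not part of the PR):

```cpp
#include <cmath>

// Hypothetical reference for round-toward-zero: negative inputs round up,
// non-negative inputs round down, so the fractional part is always dropped.
template <typename T>
T round_toward_zero(T x) {
  return x < T(0) ? std::ceil(x) : std::floor(x);
}

// Compares the reference against std::trunc for one finite input.
template <typename T>
bool agrees_with_trunc(T x) {
  return round_toward_zero(x) == std::trunc(x);
}
```

For example, `round_toward_zero(-2.7f)` yields `-2.0f` and `round_toward_zero(2.7)` yields `2.0`, the same values `std::trunc` returns for those inputs.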
Please fix
Fix #24650 Differential Revision: [D17397489](https://our.internmc.facebook.com/intern/diff/D17397489)
@VitalyFedyunin Updated. Should be fixed now.
"All checks have passed" super suspicious ;)
@VitalyFedyunin A friendly reminder that the previous two merged commits in this stack seem to have the incorrect authorship: f55a9da
Please rebase ( contention on the
Done!
Summary: Pull Request resolved: pytorch/pytorch#25423
Fix #24650
Test Plan: Imported from OSS
Differential Revision: D17397489
Pulled By: VitalyFedyunin
fbshipit-source-id: 933f915a44ff9b7803ddb2708bf0e723433ee0b6
@VitalyFedyunin merged this pull request in 7bdc0c1.
Thanks!
Stack from ghstack:
#25592 Simplify operator sign using the helper.
Fix #24650
Differential Revision: D17397489