
Move the CUDA implementation of trunc to ATen. #25423


Closed · wants to merge 14 commits

Conversation

@xuhdev (Collaborator) commented Aug 29, 2019

Stack from ghstack:

Fix #24650

Differential Revision: D17397489

void trunc_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "trunc_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return std::trunc(a);
    });
  });
}
Contributor


Are you sure there is no need to overload this to call truncf?

@xuhdev (Collaborator, author) commented Aug 29, 2019

Yeah, it's overloaded: https://en.cppreference.com/w/cpp/numeric/math/trunc (also search for VSTD::trunc in crt/math_functions.h in the CUDA include dir).
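
One way to spot-check the overload set at compile time (a quick host-side sketch, not part of this PR, assuming a C++11 compiler; the device-side behaviour is what the PTX dump below demonstrates):

#include <cmath>
#include <type_traits>

// Since C++11, <cmath> provides a float overload of std::trunc, so a float
// argument resolves to it instead of being promoted to double.
static_assert(std::is_same<decltype(std::trunc(1.0f)), float>::value,
              "std::trunc(float) resolves to the float overload");

int main() { return 0; }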

@xuhdev (Collaborator, author)

If you are worried, here's some evidence. Compile the following program with:

nvcc -ptx -src-in-ptx -arch=sm_60 test.cu

#include <cuda_runtime.h>

__global__ void test_trunc_f(float x,  float& x2) {
  x2 = std::trunc(x);
}

__global__ void test_trunc_d(double x,  double& x2) {
  x2 = std::trunc(x);
}

__global__ void test_trunc_f_d(float x,  float& x2) {
  x2 = truncf(x);
}

The output is shown below (note that the first and third functions compile to the same PTX):

//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-26907403
// Cuda compilation tools, release 10.1, V10.1.243
// Based on LLVM 3.4svn
//

.version 6.4
.target sm_60
.address_size 64

	// .globl	_Z12test_trunc_ffRf

.visible .entry _Z12test_trunc_ffRf(
	.param .f32 _Z12test_trunc_ffRf_param_0,
	.param .u64 _Z12test_trunc_ffRf_param_1
)
{
	.reg .f32 	%f<3>;
	.reg .b64 	%rd<3>;


	ld.param.f32 	%f1, [_Z12test_trunc_ffRf_param_0];
	ld.param.u64 	%rd1, [_Z12test_trunc_ffRf_param_1];
	cvta.to.global.u64 	%rd2, %rd1;
	cvt.rzi.f32.f32	%f2, %f1;
	st.global.f32 	[%rd2], %f2;
	ret;
}

	// .globl	_Z12test_trunc_ddRd
.visible .entry _Z12test_trunc_ddRd(
	.param .f64 _Z12test_trunc_ddRd_param_0,
	.param .u64 _Z12test_trunc_ddRd_param_1
)
{
	.reg .f64 	%fd<3>;
	.reg .b64 	%rd<3>;


	ld.param.f64 	%fd1, [_Z12test_trunc_ddRd_param_0];
	ld.param.u64 	%rd1, [_Z12test_trunc_ddRd_param_1];
	cvta.to.global.u64 	%rd2, %rd1;
	cvt.rzi.f64.f64	%fd2, %fd1;
	st.global.f64 	[%rd2], %fd2;
	ret;
}

	// .globl	_Z14test_trunc_f_dfRf
.visible .entry _Z14test_trunc_f_dfRf(
	.param .f32 _Z14test_trunc_f_dfRf_param_0,
	.param .u64 _Z14test_trunc_f_dfRf_param_1
)
{
	.reg .f32 	%f<3>;
	.reg .b64 	%rd<3>;


	ld.param.f32 	%f1, [_Z14test_trunc_f_dfRf_param_0];
	ld.param.u64 	%rd1, [_Z14test_trunc_f_dfRf_param_1];
	cvta.to.global.u64 	%rd2, %rd1;
	cvt.rzi.f32.f32	%f2, %f1;
	st.global.f32 	[%rd2], %f2;
	ret;
}
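
Stepping back from the overload question: for context, here is a minimal sketch of how a kernel like the one in this diff is typically wired into ATen dispatch. The stub name, include paths, and file location follow the usual DispatchStub convention and are assumptions, not copied from this PR.

// ATen/native/cuda/UnaryOpsKernel.cu (illustrative sketch)
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/Loops.cuh>

namespace at { namespace native {

void trunc_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "trunc_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      // std::trunc resolves per scalar_t: float, double, or at::Half via its float conversion.
      return std::trunc(a);
    });
  });
}

// Register against the device-generic trunc_stub (declared in ATen/native/UnaryOps.h)
// so that at::trunc on CUDA tensors dispatches to this kernel.
REGISTER_DISPATCH(trunc_stub, &trunc_kernel_cuda);

}} // namespace at::native

The CPU backend registers its own kernel against the same stub, which is presumably why this PR also touches UnaryOps.cpp, the file with the rebase contention mentioned later in the thread.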

xuhdev added a commit that referenced this pull request Sep 3, 2019
Fix #24650

ghstack-source-id: c8226da
Pull Request resolved: #25423
@VitalyFedyunin previously approved these changes Sep 16, 2019
@VitalyFedyunin dismissed their stale review September 16, 2019 15:36

rocm failure looks reasonable

@VitalyFedyunin (Contributor)

Please fix pr/py2-clang7-rocmdeb-ubuntu16.04

@xuhdev (Collaborator, author) commented Sep 16, 2019

@VitalyFedyunin Updated. It should be fixed now.

@VitalyFedyunin (Contributor)

"All checks have passed" super suspicious ;)

@xuhdev (Collaborator, author) commented Sep 21, 2019

@VitalyFedyunin A friendly reminder that the previous two merged commits in this stack appear to have the wrong authorship: f55a9da

@VitalyFedyunin (Contributor)

Please rebase (contention on UnaryOps.cpp is way too high).

@xuhdev (Collaborator, author) commented Sep 23, 2019

Done!

zdevito pushed a commit to zdevito/ATen that referenced this pull request Sep 24, 2019
Summary:
Pull Request resolved: pytorch/pytorch#25423

Fix #24650

Test Plan: Imported from OSS

Differential Revision: D17397489

Pulled By: VitalyFedyunin

fbshipit-source-id: 933f915a44ff9b7803ddb2708bf0e723433ee0b6
@facebook-github-bot (Contributor)

@VitalyFedyunin merged this pull request in 7bdc0c1.

@xuhdev (Collaborator, author) commented Sep 24, 2019

Thanks!

@xuhdev deleted the gh/xuhdev/33/head branch September 24, 2019 17:46
Labels
Merged · module: cuda · module: internals · open source