-
Notifications
You must be signed in to change notification settings - Fork 68
rocFFT integration #139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
rocFFT integration #139
Changes from all commits
84a985e
5e710fc
c40247a
777c095
1322aa1
6098f41
a02c34f
3bd1512
8f8b653
b67e1ee
103e534
8709047
1f0ab4b
6b85570
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -149,7 +149,11 @@ class CuFFTConfig { | |
// TODO: Figure out why windows fails to compile | ||
// at::optional<std::vector<long long int>> inembed_opt = at::nullopt; | ||
// Then move the following to a helper function. | ||
#ifdef __HIP_PLATFORM_HCC__ | ||
std::vector<int> inembed(signal_ndim); | ||
#else | ||
std::vector<long long int> inembed(signal_ndim); | ||
#endif | ||
if (!clone_input) { | ||
auto istrides = input.strides(); | ||
auto last_istride = istrides[signal_ndim]; | ||
|
@@ -192,6 +196,37 @@ class CuFFTConfig { | |
inembed.begin()); // begin of output | ||
} | ||
|
||
#ifdef __HIP_PLATFORM_HCC__ | ||
|
||
hipfftType exec_type; | ||
if (input.type().scalarType() == ScalarType::Float) { | ||
if (complex_input && complex_output) { | ||
exec_type = HIPFFT_C2C; | ||
} else if (complex_input && !complex_output) { | ||
exec_type = HIPFFT_C2R; | ||
} else if (!complex_input && complex_output) { | ||
exec_type = HIPFFT_R2C; | ||
} else { | ||
throw std::runtime_error("hipFFT doesn't support r2r (float)"); | ||
} | ||
} else if (input.type().scalarType() == ScalarType::Double) { | ||
if (complex_input && complex_output) { | ||
exec_type = HIPFFT_Z2Z; | ||
} else if (complex_input && !complex_output) { | ||
exec_type = HIPFFT_Z2D; | ||
} else if (!complex_input && complex_output) { | ||
exec_type = HIPFFT_D2Z; | ||
} else { | ||
throw std::runtime_error("hipFFT doesn't support r2r (double)"); | ||
} | ||
} else { | ||
std::ostringstream ss; | ||
ss << "hipFFT doesn't support tensor of type: " | ||
<< at::toString(input.type().scalarType()); | ||
throw std::runtime_error(ss.str()); | ||
} | ||
|
||
#else | ||
cudaDataType itype, otype, exec_type; | ||
if (input.type().scalarType() == ScalarType::Float) { | ||
itype = complex_input ? CUDA_C_32F : CUDA_R_32F; | ||
|
@@ -211,6 +246,7 @@ class CuFFTConfig { | |
<< at::toString(input.type().scalarType()); | ||
throw std::runtime_error(ss.str()); | ||
} | ||
#endif | ||
|
||
// create plan | ||
auto raw_plan_ptr = new cufftHandle(); | ||
|
@@ -229,10 +265,18 @@ class CuFFTConfig { | |
// by assuming base_istride = base_ostride = 1. | ||
// | ||
// See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu. | ||
#ifdef __HIP_PLATFORM_HCC__ | ||
int sizes = *signal_sizes.data(); | ||
CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, &sizes, | ||
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, | ||
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, | ||
exec_type, batch, &ws_size_t)); | ||
#else | ||
CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(), | ||
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype, | ||
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype, | ||
batch, &ws_size_t, exec_type)); | ||
#endif | ||
} else { | ||
// set idist (stride at batch dim) | ||
// set base_istride (stride at innermost dim of signal) | ||
|
@@ -254,6 +298,19 @@ class CuFFTConfig { | |
} | ||
|
||
// set odist, onembed, base_ostride | ||
#ifdef __HIP_PLATFORM_HCC__ | ||
int odist = at::prod_intlist(output_sizes.slice(1, signal_ndim)); | ||
std::vector<int> onembed(output_sizes.data() + 1, output_sizes.data() + signal_ndim + 1); | ||
int base_ostride = 1; | ||
|
||
int sizes = *signal_sizes.data(); | ||
int istride = base_istride; | ||
int iidist = idist; | ||
CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, &sizes, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't really need to create a new variable. Can simply pass signal_sizes.data(). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually we do. signal_sizes.data() returns a |
||
inembed.data(), istride, iidist, | ||
onembed.data(), base_ostride, odist, | ||
exec_type, batch, &ws_size_t)); | ||
#else | ||
long long int odist = at::prod_intlist(output_sizes.slice(1, signal_ndim)); | ||
std::vector<long long int> onembed(output_sizes.data() + 1, output_sizes.data() + signal_ndim + 1); | ||
long long int base_ostride = 1; | ||
|
@@ -262,11 +319,16 @@ class CuFFTConfig { | |
inembed.data(), base_istride, idist, itype, | ||
onembed.data(), base_ostride, odist, otype, | ||
batch, &ws_size_t, exec_type)); | ||
} | ||
#endif | ||
} | ||
ws_size = static_cast<int64_t>(ws_size_t); | ||
} | ||
|
||
#ifdef __HIP_PLATFORM_HCC__ | ||
cufftHandle &plan() const { return *plan_ptr.get(); } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why remove the constness here? The signatures for hipfftCreate & cufftCreate are identical. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because the signatures for MakeMany are not. |
||
#else | ||
const cufftHandle &plan() const { return *plan_ptr.get(); } | ||
#endif | ||
|
||
bool should_clone_input() const { return clone_input; } | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -190,8 +190,45 @@ static inline Tensor _run_cufft( | |
CUFFT_CHECK(cufftSetWorkArea(plan, ws.data_ptr())); | ||
|
||
// run | ||
#ifdef __HIP_PLATFORM_HCC__ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is fine for now but would be good to file an issue under the rocFFT to further extend their API for a cufftXtExec-esque call. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Already done.... |
||
if (input.type().scalarType() == ScalarType::Float) { | ||
if (complex_input && complex_output) { | ||
CUFFT_CHECK(hipfftExecC2C(plan, static_cast<hipfftComplex*>(input.data_ptr()), | ||
static_cast<hipfftComplex*>(output.data_ptr()), | ||
inverse ? HIPFFT_BACKWARD : HIPFFT_FORWARD)); | ||
} else if (complex_input && !complex_output) { | ||
CUFFT_CHECK(hipfftExecC2R(plan, static_cast<hipfftComplex*>(input.data_ptr()), | ||
static_cast<hipfftReal*>(output.data_ptr()))); | ||
} else if (!complex_input && complex_output) { | ||
CUFFT_CHECK(hipfftExecR2C(plan, static_cast<hipfftReal*>(input.data_ptr()), | ||
static_cast<hipfftComplex*>(output.data_ptr()))); | ||
} else { | ||
throw std::runtime_error("hipFFT doesn't support r2r (float)"); | ||
} | ||
} else if (input.type().scalarType() == ScalarType::Double) { | ||
if (complex_input && complex_output) { | ||
CUFFT_CHECK(hipfftExecZ2Z(plan, static_cast<hipfftDoubleComplex*>(input.data_ptr()), | ||
static_cast<hipfftDoubleComplex*>(output.data_ptr()), | ||
inverse ? HIPFFT_BACKWARD : HIPFFT_FORWARD)); | ||
} else if (complex_input && !complex_output) { | ||
CUFFT_CHECK(hipfftExecZ2D(plan, static_cast<hipfftDoubleComplex*>(input.data_ptr()), | ||
static_cast<hipfftDoubleReal*>(output.data_ptr()))); | ||
} else if (!complex_input && complex_output) { | ||
CUFFT_CHECK(hipfftExecD2Z(plan, static_cast<hipfftDoubleReal*>(input.data_ptr()), | ||
static_cast<hipfftDoubleComplex*>(output.data_ptr()))); | ||
} else { | ||
throw std::runtime_error("hipFFT doesn't support r2r (double)"); | ||
} | ||
} else { | ||
std::ostringstream ss; | ||
ss << "hipFFT doesn't support tensor of type: " | ||
<< at::toString(input.type().scalarType()); | ||
throw std::runtime_error(ss.str()); | ||
} | ||
#else | ||
CUFFT_CHECK(cufftXtExec(plan, input.data_ptr(), output.data_ptr(), | ||
inverse ? CUFFT_INVERSE : CUFFT_FORWARD)); | ||
#endif | ||
|
||
// rescale if needed by normalized flag or inverse transform | ||
auto size_last_signal_dim = checked_signal_sizes[signal_ndim - 1]; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't really need to create a new variable. Can simply pass signal_sizes.data().
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually we do. signal_sizes.data() returns a
long long*
and rocFFT needs aint*