Add variance_mean composite function using Welford #1907
@@ -69,6 +69,64 @@ TensorView* variance(

```cpp
  return y;
}

TORCH_CUDA_CU_API VarMeanResult variance_mean(
    TensorView* x,
    const std::vector<int>& dims,
    int64_t correction,
    bool keepdim) {
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");

  TORCH_CHECK(
      correction >= 0, "correction must be non-negative, but got ", correction);

  // There are compilation errors for half precision
  auto dtype = x->getDataType().value();
  TORCH_CHECK(
      !(dtype == DataType::Half || dtype == DataType::BFloat16),
      "variance_mean is not supported for ",
      dtype,
      " please upcast to float");

  if (isComplexType(x->getDataType().value())) {
    // There are compilation errors:
    // __tmp_kernel1.cu(6727): error: namespace "CudaCodeGen::std" has no member
```

> **Review comment:** Doesn't sound very right to me. cc'ing @zasdfgbnm
>
> **Reply (@zasdfgbnm):** This is definitely not the correct behavior, but I am not surprised at all 😜. I got moved to more important tasks before finishing complex support, so currently complex support in nvfuser is very broken. Please consider complex as not supported when doing other work.
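For reference, the identity this disabled branch relies on: for a complex random variable $z$,

$$
\operatorname{Var}(z) = \mathbb{E}\big[\lvert z - \mathbb{E}[z]\rvert^2\big]
= \operatorname{Var}(\operatorname{Re} z) + \operatorname{Var}(\operatorname{Im} z),
$$

which is real-valued, while the mean $\mathbb{E}[z] = \mathbb{E}[\operatorname{Re} z] + i\,\mathbb{E}[\operatorname{Im} z]$ remains complex. That is why the code below can simply sum the two real-part variances, but would need complex-from-real tensor construction to return the mean.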
```cpp
    // "imagf"
    // __tmp_kernel1.cu(6753): error: namespace "CudaCodeGen::std" has no member
    // "realf"
    TORCH_CHECK(false, "var_mean is not supported for complex types.");
    auto out_real = variance_mean(real(x), dims, correction, keepdim);
    auto out_imag = variance_mean(imag(x), dims, correction, keepdim);
    // variance of a complex tensor is the sum of real and imaginary variances
    // and is real; mean of a complex tensor is complex(out_real.mean,
    // out_imag.mean). It seems construction of a complex tensor from two real
    // tensors is not supported yet.
    return {add(out_real.var, out_imag.var), nullptr};
  }
  const int kNumberOfDims =
      TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();
  auto num_features = numFeatures(x, dims, kNumberOfDims);
  if (correction > 0) {
    num_features =
        sub(num_features, IrBuilder::create<Int>(x->container(), correction));
  }

  auto welford_out = Welford(x, dims);
```

> **Review comment:** I would hope that we can have …
>
> **Reply:** We can of course try using
>
> ```cpp
> TensorView* variance(
>     TensorView* x,
>     const std::vector<int>& dims,
>     int64_t correction,
>     bool keepdim) {
>   auto var_mean = variance_mean(x, dims, correction, keepdim);
>   return var_mean.var;
> }
> ```
>
> if it positively impacts the performance. That's how ATen implements this function and it would justify …
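For context on the `Welford(x, dims)` call above: `welford_out.avg` is the running mean and `welford_out.var_sum` is the running sum of squared deviations (usually called M2 in descriptions of Welford's algorithm), which is why the code below divides it by `num_features` to get the variance. A minimal standalone scalar sketch of the update rule (hypothetical illustration code, not nvfuser's codegen):

```cpp
#include <cstdio>
#include <vector>

// One-pass Welford accumulation: maintains count, mean, and M2
// (the sum of squared deviations from the current mean).
struct WelfordAcc {
  long long n = 0;
  double mean = 0.0;
  double m2 = 0.0;  // plays the role of welford_out.var_sum in the diff

  void update(double x) {
    ++n;
    double delta = x - mean;
    mean += delta / n;
    m2 += delta * (x - mean);  // uses the already-updated mean
  }

  // correction = 0 -> population variance, 1 -> sample variance
  double variance(long long correction) const {
    return m2 / static_cast<double>(n - correction);
  }
};

int main() {
  WelfordAcc acc;
  for (double v : std::vector<double>{1.0, 2.0, 3.0, 4.0}) {
    acc.update(v);
  }
  // mean = 2.5, population variance = 1.25, sample variance ~ 1.6667
  std::printf("mean=%f var0=%f var1=%f\n",
              acc.mean, acc.variance(0), acc.variance(1));
  return 0;
}
```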
```cpp
  auto mean = welford_out.avg;
  auto var = mul(welford_out.var_sum, reciprocal(num_features));

  if (keepdim) {
    std::vector<bool> is_broadcast(kNumberOfDims, false);
    for (auto dim : dims) {
      is_broadcast[dim] = true;
    }
    var = broadcast(var, is_broadcast);
    mean = broadcast(mean, is_broadcast);
  }

  return {var, mean};
}

TensorView* standard_deviation(
    TensorView* x,
    const std::vector<int>& dims,
```
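Putting the pieces together: with $N$ the number of reduced elements and $c$ the correction, the composite computes

$$
\text{mean} = \frac{1}{N}\sum_i x_i, \qquad
\text{var} = \frac{1}{N - c}\sum_i \left(x_i - \text{mean}\right)^2,
$$

since `welford_out.var_sum` holds the sum of squared deviations and `num_features` is reduced to $N - c$ when `correction > 0`. `correction = 1` gives the Bessel-corrected sample variance; `correction = 0` gives the population variance.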
> **Review comment:** Should we consider just promoting the math to float for reduced precision, then? That seems to be the universal approach used by everyone else computing variance.
>
> **Reply:** I handle type promotion on the Python side.
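If in-kernel promotion were ever adopted instead, one conceivable shape is sketched below, assuming nvfuser's `castOp(DataType, TensorView*)` from arith.h is available here. The `variance_mean_promoting` name is hypothetical, and this is not what the PR does (promotion is handled in Python):

```cpp
// Hypothetical: upcast reduced-precision inputs before the Welford reduction,
// then cast the results back to the original dtype.
TORCH_CUDA_CU_API VarMeanResult variance_mean_promoting(
    TensorView* x,
    const std::vector<int>& dims,
    int64_t correction,
    bool keepdim) {
  auto dtype = x->getDataType().value();
  bool reduced = (dtype == DataType::Half || dtype == DataType::BFloat16);
  if (reduced) {
    x = castOp(DataType::Float, x);  // compute in float
  }
  auto out = variance_mean(x, dims, correction, keepdim);
  if (reduced) {
    out.var = castOp(dtype, out.var);
    out.mean = castOp(dtype, out.mean);
  }
  return out;
}
```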