Add variance_mean composite function using Welford #1907

Merged (1 commit, Aug 19, 2022)
58 changes: 58 additions & 0 deletions torch/csrc/jit/codegen/cuda/ops/normalization.cpp
@@ -69,6 +69,64 @@ TensorView* variance(
return y;
}

TORCH_CUDA_CU_API VarMeanResult variance_mean(
TensorView* x,
const std::vector<int>& dims,
int64_t correction,
bool keepdim) {
TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");

TORCH_CHECK(
correction >= 0, "correction must be non-negative, but got ", correction);

// There are compilation errors for half precision
Collaborator:

Should we consider just promoting the math to float for reduced precision, then? That seems to be the universal approach used by everything else that computes variance.

Collaborator Author:

I handle type promotion on the Python side.
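
For reference, a minimal sketch of the in-kernel alternative discussed above (promote reduced-precision inputs to float, then cast the results back). It is not part of this PR; the wrapper name is hypothetical and it assumes nvfuser's castOp(DataType, TensorView*) helper:

// Hypothetical wrapper (not in this PR): upcast Half/BFloat16 inputs to
// float, compute variance_mean in float, then cast the outputs back.
VarMeanResult variance_mean_promoted(
    TensorView* x,
    const std::vector<int>& dims,
    int64_t correction,
    bool keepdim) {
  auto in_dtype = x->getDataType().value();
  const bool reduced =
      in_dtype == DataType::Half || in_dtype == DataType::BFloat16;
  // Promote the math to float for reduced-precision inputs.
  auto x_f = reduced ? castOp(DataType::Float, x) : x;
  auto out = variance_mean(x_f, dims, correction, keepdim);
  if (reduced) {
    // Match the input precision on the way out.
    out.var = castOp(in_dtype, out.var);
    out.mean = castOp(in_dtype, out.mean);
  }
  return out;
}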

auto dtype = x->getDataType().value();
TORCH_CHECK(
!(dtype == DataType::Half || dtype == DataType::BFloat16),
"variance_mean is not supported for ",
dtype,
" please upcast to float");

if (isComplexType(x->getDataType().value())) {
// There are compilation errors:
// __tmp_kernel1.cu(6727): error: namespace "CudaCodeGen::std" has no member
Collaborator:

Doesn't sound very right to me. cc'ing @zasdfgbnm

Collaborator:

This is definitely not the correct behavior, but I am not surprised at all 😜. I got moved to more important tasks before finishing complex support, so currently complex support in nvfuser is very broken. Please consider complex as not supported when doing other work.
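
A standard identity (stated here for clarity, not taken from the thread) underlies the disabled complex branch below: for a complex value $z = a + bi$ with real $a$ and $b$,

$$\operatorname{Var}(z) = \mathbb{E}\left[\lvert z - \mathbb{E}[z]\rvert^{2}\right] = \operatorname{Var}(a) + \operatorname{Var}(b), \qquad \mathbb{E}[z] = \mathbb{E}[a] + i\,\mathbb{E}[b],$$

so the variance of a complex tensor is real-valued while its mean is complex, which is what the code comment below describes.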

// "imagf"
// __tmp_kernel1.cu(6753): error: namespace "CudaCodeGen::std" has no member
// "realf"
TORCH_CHECK(false, "var_mean is not supported for complex types.");
auto out_real = variance_mean(real(x), dims, correction, keepdim);
auto out_imag = variance_mean(imag(x), dims, correction, keepdim);
// The variance of a complex tensor is the sum of the real and imaginary
// variances and is real. The mean of a complex tensor is complex:
// complex(out_real.mean, out_imag.mean). It seems construction of a complex
// tensor from two real tensors is not supported yet.
return {add(out_real.var, out_imag.var), nullptr};
}

const int kNumberOfDims =
TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();
auto num_features = numFeatures(x, dims, kNumberOfDims);
if (correction > 0) {
num_features =
sub(num_features, IrBuilder::create<Int>(x->container(), correction));
}

auto welford_out = Welford(x, dims);
Collaborator:

I would hope that we can have variance also use Welford, so we can unify the interface for variance and variance_mean.
But WelfordOp appears to be a single op in codegen; does this mean the generated kernel would still write the mean to a register somewhere, even though it is not used at all? That is probably something we do not want.

cc'ing @naoyam @shmsong on this.

Collaborator Author:

We can of course try using

TensorView* variance(
    TensorView* x,
    const std::vector<int>& dims,
    int64_t correction,
    bool keepdim) {
  auto var_mean = variance_mean(x, dims, correction, keepdim);
  return var_mean.var;
}

if it positively impacts performance. That is how ATen implements this function, and it would justify var being a prim in PrimTorch.
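
For context, the standard Welford recurrence (restated here, not quoted from the thread) shows why a single pass yields both outputs: with running mean $\mu_k$ and running sum of squared deviations $M_k$ (nvfuser's var_sum),

$$\mu_k = \mu_{k-1} + \frac{x_k - \mu_{k-1}}{k}, \qquad M_k = M_{k-1} + (x_k - \mu_{k-1})(x_k - \mu_k),$$

and after all $N$ elements the mean is $\mu_N$ while the corrected variance is $M_N / (N - \text{correction})$, which is exactly the mul(welford_out.var_sum, reciprocal(num_features)) conversion in the diff below.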

auto mean = welford_out.avg;
auto var = mul(welford_out.var_sum, reciprocal(num_features));

if (keepdim) {
std::vector<bool> is_broadcast(kNumberOfDims, false);
for (auto dim : dims) {
is_broadcast[dim] = true;
}
var = broadcast(var, is_broadcast);
mean = broadcast(mean, is_broadcast);
}

return {var, mean};
}

TensorView* standard_deviation(
TensorView* x,
const std::vector<int>& dims,
11 changes: 11 additions & 0 deletions torch/csrc/jit/codegen/cuda/ops/normalization.h
@@ -38,6 +38,11 @@ struct BackwardRMSNormResult {
TensorView* grad_weight = nullptr;
};

struct VarMeanResult {
TensorView* var = nullptr;
TensorView* mean = nullptr;
};

TORCH_CUDA_CU_API TensorView* mean(
TensorView* x,
const std::vector<int>& dims,
@@ -55,6 +60,12 @@ TORCH_CUDA_CU_API TensorView* variance(
int64_t correction,
bool keepdim);

TORCH_CUDA_CU_API VarMeanResult variance_mean(
TensorView* x,
const std::vector<int>& dims,
int64_t correction,
bool keepdim);

TORCH_CUDA_CU_API TensorView* standard_deviation(
TensorView* x,
const std::vector<int>& dims,
44 changes: 44 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -13150,6 +13150,50 @@ TEST_F(NVFuserTest, FusionWelfordShmoo_CUDA) {
}
}

namespace {
void testVarMean(at::ScalarType dtype, int correction, bool keepdim) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

int M = 64, N = 128;

auto tv0 = makeSymbolicTensor(2, aten_to_data_type(dtype));
fusion->addInput(tv0);
auto tvs = variance_mean(tv0, {1}, correction, keepdim);
auto tv_mean = tvs.mean;
auto tv_var = tvs.var;
fusion->addOutput(tv_var);
fusion->addOutput(tv_mean);

auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
at::manual_seed(0);
at::Tensor t0 = at::randn({M, N}, options);

FusionExecutorCache executor_cache(std::move(fusion));
auto outputs = executor_cache.runFusionWithInputs({t0});

auto at_var_mean = at::var_mean(t0, {1}, correction, keepdim);
std::vector<at::Tensor> aten_outputs = {
std::get<0>(at_var_mean), std::get<1>(at_var_mean)};

testValidate(
executor_cache.fusion(), outputs, {t0}, aten_outputs, __LINE__, __FILE__);
}
} // namespace

TEST_F(NVFuserTest, FusionVarMean_CUDA) {
std::vector<at::ScalarType> dtypes = {at::kFloat, at::kDouble};
std::vector<int> corrections = {0, 1};
std::vector<bool> keepdims = {false, true};
for (auto correction : corrections) {
for (auto keepdim : keepdims) {
for (auto dtype : dtypes) {
testVarMean(dtype, correction, keepdim);
}
}
}
}

TEST_F(NVFuserTest, FusionSimpleGemmTransposed_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
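
For readers unfamiliar with the algorithm named in the PR title, a self-contained serial sketch of Welford's method that produces the same var/mean pair per reduced row. It is a standalone illustration with hypothetical names (welford_var_mean, VarMean), not nvfuser code:

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

struct VarMean {
  double var;
  double mean;
};

// One-pass (Welford) variance and mean: mean is the running average,
// var is the accumulated sum of squared deviations divided by (N - correction).
VarMean welford_var_mean(const std::vector<double>& xs, int64_t correction) {
  double mean = 0.0;
  double m2 = 0.0;  // running sum of squared deviations from the current mean
  int64_t n = 0;
  for (double x : xs) {
    ++n;
    double delta = x - mean;
    mean += delta / static_cast<double>(n);
    m2 += delta * (x - mean);  // old deviation times new deviation
  }
  assert(n - correction > 0);
  return {m2 / static_cast<double>(n - correction), mean};
}

int main() {
  std::vector<double> xs = {1.0, 2.0, 3.0, 4.0};
  auto out = welford_var_mean(xs, /*correction=*/1);
  // Expected: var = 5.0 / 3 ≈ 1.6667, mean = 2.5 (matches at::var_mean with correction=1).
  std::printf("var=%f mean=%f\n", out.var, out.mean);
  return 0;
}

With correction=1 this gives the sample variance that at::var_mean returns by default; correction=0 gives the population variance.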