Commit 35440b7

Patching BN inference (#2016)

Fixes BN inference. I'm stealing Ivan's changes from pytorch#85562. In ATen we are returning mini-batch stats during an inference run; that is not the right behavior, and it should be fixed there instead. But for the time being, let's change nvFuser's behavior just to get CI green. Also, the extra `set` here, added to avoid trivial forwarding, should be removed once #1995 is merged.
1 parent 0f9f0b4 commit 35440b7

File tree

1 file changed: +8 −4 lines

torch/csrc/jit/codegen/cuda/ops/normalization.cpp

Lines changed: 8 additions & 4 deletions
@@ -587,8 +587,10 @@ ForwardNormResult batch_norm(
     auto invstd_bcast = broadcast(unbiased_invstd, broadcast_mask);
 
     // During inference, mean/invstd output are empty tensors
-    mean = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
-    invstd = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
+    // on CPU, but not on CUDA. We need to make sure we have the same
+    // behavior as with eager mode on CUDA.
+    mean = set(running_mean);
+    invstd = unbiased_invstd;
     y = mul(x_sub_mean, invstd_bcast);
   }

@@ -840,8 +842,10 @@ ForwardNormResult instance_norm(
         broadcast(unbiased_invstd, channels_only_broadcast_mask);
 
     // During inference, mean/invstd output are empty tensors
-    mean = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
-    invstd = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
+    // on CPU, but not on CUDA. We need to make sure we have the same
+    // behavior as with eager mode on CUDA.
+    mean = set(running_mean);
+    invstd = unbiased_invstd;
     y = mul(x_sub_mean, invstd_bcast);
   }
