Skip to content

Commit be6ad7d

Browse files
Will Feng and facebook-github-bot
Will Feng
authored and committed
Rename BatchNorm running_variance to running_var (pytorch#17371)
Summary: Currently there is a mismatch in naming between Python BatchNorm `running_var` and C++ BatchNorm `running_variance`, which causes JIT model parameters loading to fail (pytorch/vision#728 (comment)): ``` terminate called after throwing an instance of 'c10::Error' what(): No such serialized tensor 'running_variance' (read at /home/shahriar/Build/pytorch/torch/csrc/api/src/serialize/input-archive.cpp:27) frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x85 (0x7f2d92d32f95 in /usr/local/lib/libc10.so) frame #1: torch::serialize::InputArchive::read(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, at::Tensor&, bool) + 0xdeb (0x7f2d938551ab in /usr/local/lib/libtorch.so.1) frame #2: torch::nn::Module::load(torch::serialize::InputArchive&) + 0x98 (0x7f2d9381cd08 in /usr/local/lib/libtorch.so.1) frame #3: torch::nn::Module::load(torch::serialize::InputArchive&) + 0xf9 (0x7f2d9381cd69 in /usr/local/lib/libtorch.so.1) frame #4: torch::nn::Module::load(torch::serialize::InputArchive&) + 0xf9 (0x7f2d9381cd69 in /usr/local/lib/libtorch.so.1) frame #5: torch::nn::operator>>(torch::serialize::InputArchive&, std::shared_ptr<torch::nn::Module> const&) + 0x32 (0x7f2d9381c7b2 in /usr/local/lib/libtorch.so.1) frame #6: <unknown function> + 0x2b16c (0x5645f4d1916c in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest) frame #7: <unknown function> + 0x27a3c (0x5645f4d15a3c in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest) frame #8: <unknown function> + 0x2165c (0x5645f4d0f65c in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest) frame #9: <unknown function> + 0x1540b (0x5645f4d0340b in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest) frame #10: 
__libc_start_main + 0xf3 (0x7f2d051dd223 in /usr/lib/libc.so.6) frame #11: <unknown function> + 0x1381e (0x5645f4d0181e in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest) ``` Renaming C++ BatchNorm `running_variance` to `running_var` should fix this problem. This is a BC-breaking change, but it should be easy for end user to rename `running_variance` to `running_var` in their call sites. Pull Request resolved: pytorch#17371 Reviewed By: goldsborough Differential Revision: D14172775 Pulled By: yf225 fbshipit-source-id: b9d3729ec79272a8084269756f28a8f7c4dd16b6
1 parent 562fa55 commit be6ad7d

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

test/cpp/api/modules.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,9 @@ TEST_F(ModulesTest, BatchNormStateful) {
248248
ASSERT_EQ(bn->running_mean.dim(), 1);
249249
ASSERT_EQ(bn->running_mean.size(0), 5);
250250

251-
ASSERT_TRUE(bn->running_variance.defined());
252-
ASSERT_EQ(bn->running_variance.dim(), 1);
253-
ASSERT_EQ(bn->running_variance.size(0), 5);
251+
ASSERT_TRUE(bn->running_var.defined());
252+
ASSERT_EQ(bn->running_var.dim(), 1);
253+
ASSERT_EQ(bn->running_var.size(0), 5);
254254

255255
// Is affine by default.
256256
ASSERT_TRUE(bn->options.affine());
@@ -267,7 +267,7 @@ TEST_F(ModulesTest, BatchNormStateless) {
267267
BatchNorm bn(BatchNormOptions(5).stateful(false).affine(false));
268268

269269
ASSERT_FALSE(bn->running_mean.defined());
270-
ASSERT_FALSE(bn->running_variance.defined());
270+
ASSERT_FALSE(bn->running_var.defined());
271271
ASSERT_FALSE(bn->weight.defined());
272272
ASSERT_FALSE(bn->bias.defined());
273273

torch/csrc/api/include/torch/nn/modules/batchnorm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class TORCH_API BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {
8888

8989
/// The running variance.
9090
/// Only defined if the `stateful` option was `true` upon construction.
91-
Tensor running_variance;
91+
Tensor running_var;
9292
};
9393

9494
/// A `ModuleHolder` subclass for `BatchNormImpl`.

torch/csrc/api/src/nn/modules/batchnorm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ void BatchNormImpl::reset() {
2828
if (options.stateful_) {
2929
running_mean =
3030
register_buffer("running_mean", torch::zeros({options.features_}));
31-
running_variance =
32-
register_buffer("running_variance", torch::ones({options.features_}));
31+
running_var =
32+
register_buffer("running_var", torch::ones({options.features_}));
3333
}
3434
}
3535

@@ -47,7 +47,7 @@ Tensor BatchNormImpl::forward(const Tensor& input) {
4747
"Calling BatchNorm::forward is only permitted when "
4848
"the 'stateful' option is true (was false). "
4949
"Use BatchNorm::pure_forward instead.");
50-
return pure_forward(input, running_mean, running_variance);
50+
return pure_forward(input, running_mean, running_var);
5151
}
5252

5353
Tensor BatchNormImpl::pure_forward(

0 commit comments

Comments
 (0)