Skip to content

Commit 1f8ad1b

Browse files
committed
Rename kTraining -> training
1 parent 3adc1b7 commit 1f8ad1b

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,14 +1353,14 @@ struct BatchNormOpRecord : RecordFunctor {
13531353
BatchNormOpRecord(
13541354
std::vector<State> args,
13551355
std::vector<State> outputs,
1356-
bool kTraining,
1356+
bool training,
13571357
bool channels_last)
13581358
: RecordFunctor(
13591359
std::move(args),
13601360
std::move(outputs),
13611361
"ops.batch_norm",
13621362
RecordType::BatchNormOp),
1363-
kTraining_(kTraining),
1363+
training_(training),
13641364
channels_last_(channels_last) {}
13651365
virtual ~BatchNormOpRecord() = default;
13661366
virtual RecordFunctor* clone() final {
@@ -1371,15 +1371,15 @@ struct BatchNormOpRecord : RecordFunctor {
13711371
auto result = false;
13721372
if (auto child_ptr = dynamic_cast<const BatchNormOpRecord*>(&other)) {
13731373
result = RecordFunctor::operator==(other);
1374-
result = result && (kTraining_ == child_ptr->kTraining_);
1374+
result = result && (training_ == child_ptr->training_);
13751375
result = result && (channels_last_ == child_ptr->channels_last_);
13761376
}
13771377
return result;
13781378
}
13791379

13801380
virtual size_t hash() const final {
13811381
auto result = RecordFunctor::hash();
1382-
return result | (static_cast<size_t>(kTraining_) << 28) |
1382+
return result | (static_cast<size_t>(training_) << 28) |
13831383
(static_cast<size_t>(channels_last_) << 29);
13841384
}
13851385

@@ -1399,7 +1399,7 @@ struct BatchNormOpRecord : RecordFunctor {
13991399
bias,
14001400
running_mean,
14011401
running_var,
1402-
kTraining_,
1402+
training_,
14031403
momentum,
14041404
eps,
14051405
channels_last_);
@@ -1409,7 +1409,7 @@ struct BatchNormOpRecord : RecordFunctor {
14091409
}
14101410

14111411
private:
1412-
bool kTraining_;
1412+
bool training_;
14131413
bool channels_last_;
14141414
};
14151415

torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,7 +1277,7 @@ void initNvFuserPythonBindings(PyObject* module) {
12771277
nvfuser::Tensor bias,
12781278
nvfuser::Tensor running_mean,
12791279
nvfuser::Tensor running_var,
1280-
bool kTraining,
1280+
bool training,
12811281
nvfuser::Scalar momentum,
12821282
nvfuser::Scalar eps,
12831283
bool channels_last) -> decltype(auto) {
@@ -1297,7 +1297,7 @@ void initNvFuserPythonBindings(PyObject* module) {
12971297
{fd->recordingState(output()),
12981298
fd->recordingState(mean()),
12991299
fd->recordingState(invstd())},
1300-
kTraining,
1300+
training,
13011301
channels_last));
13021302
return std::make_tuple(output, mean, invstd);
13031303
},
@@ -1306,7 +1306,7 @@ void initNvFuserPythonBindings(PyObject* module) {
13061306
py::arg("bias").none(true),
13071307
py::arg("running_mean").none(true),
13081308
py::arg("running_var").none(true),
1309-
py::arg("kTraining"),
1309+
py::arg("training"),
13101310
py::arg("momentum"),
13111311
py::arg("eps"),
13121312
py::arg("channels_last") = false,

0 commit comments

Comments
 (0)