Reduce variance of classification references evaluation #4609

Merged: 6 commits, Oct 13, 2021
references/classification/train.py: 32 additions & 2 deletions

@@ -1,6 +1,7 @@
 import datetime
 import os
 import time
+import warnings
 
 import presets
 import torch
@@ -54,6 +55,8 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = f"Test: {log_suffix}"
+
+    num_processed_samples = 0
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
@@ -68,7 +71,23 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
             metric_logger.update(loss=loss.item())
             metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
             metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+            num_processed_samples += batch_size
     # gather the stats from all processes
+
+    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+    if (
+        hasattr(data_loader.dataset, "__len__")
+        and len(data_loader.dataset) != num_processed_samples
+        and torch.distributed.get_rank() == 0
+    ):
+        # See FIXME above
+        warnings.warn(
+            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
+            "samples were used for the validation, which might bias the results. "
+            "Try adjusting the batch size and / or the world size. "
+            "Setting the world size to 1 is always a safe bet."
+        )
+
     metric_logger.synchronize_between_processes()
 
     print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
@@ -147,7 +166,7 @@ def load_data(traindir, valdir, args):
     print("Creating data loaders")
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
-        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
+        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
     else:
         train_sampler = torch.utils.data.RandomSampler(dataset)
         test_sampler = torch.utils.data.SequentialSampler(dataset_test)
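A related aside on the shuffle=False change (again a toy sketch, not from the PR): with shuffling enabled, the shard a rank receives, and therefore which padded duplicates it evaluates, depends on the sampler's seed and epoch; shuffle=False keeps the validation shards fixed across runs.

from torch.utils.data.distributed import DistributedSampler

ds = list(range(8))
fixed = DistributedSampler(ds, num_replicas=2, rank=0, shuffle=False)
print(list(fixed))  # always [0, 2, 4, 6]

shuffled = DistributedSampler(ds, num_replicas=2, rank=0, shuffle=True)
shuffled.set_epoch(1)  # a different epoch (or seed) yields a different shard
print(list(shuffled))  # permutation-dependent, e.g. [5, 3, 0, 7]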
@@ -164,7 +183,11 @@ def main(args):
 
     device = torch.device(args.device)
 
-    torch.backends.cudnn.benchmark = True
+    if args.use_deterministic_algorithms:
+        torch.backends.cudnn.benchmark = False
+        torch.use_deterministic_algorithms(True)
+    else:
+        torch.backends.cudnn.benchmark = True
 
     train_dir = os.path.join(args.data_path, "train")
     val_dir = os.path.join(args.data_path, "val")
@@ -277,6 +300,10 @@ def main(args):
         model_ema.load_state_dict(checkpoint["model_ema"])
 
     if args.test_only:
+        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
Inline comment from the PR author (Member) on the change above:
I deliberately chose not to set torch.use_deterministic_algorithms(True) here because:

  • just removing the cudnn benchmarking is enough to get constant results, at least for resnet18 (maybe not for other models)
  • torch.use_deterministic_algorithms(True) forces the user to set environment variables like CUBLAS_WORKSPACE_CONFIG=:4096:8, otherwise the script crashes (see the sketch below)
  • users can always pass the new --use-deterministic-algorithms flag if they really want full determinism
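For illustration, a hedged sketch of what opting in ends up requiring (the env variable value is the one mentioned above; needed on CUDA 10.2+, where cuBLAS otherwise raises a RuntimeError once determinism is enforced):

import os

# must be in the environment before the first cuBLAS call, or deterministic mode makes cuBLAS ops fail
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import torch

torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)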


+
         evaluate(model, criterion, data_loader_test, device=device)
         return

@@ -394,6 +421,9 @@ def get_args_parser(add_help=True):
         default=0.9,
         help="decay factor for Exponential Moving Average of model parameters(default: 0.9)",
     )
+    parser.add_argument(
+        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
+    )
 
     return parser
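Usage note (hedged; the exact launcher and data paths depend on your setup): a reproducible accuracy check would then combine --test-only with the new --use-deterministic-algorithms flag, with CUBLAS_WORKSPACE_CONFIG=:4096:8 exported in the environment first.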

references/classification/utils.py: 10 additions & 5 deletions
@@ -32,11 +32,7 @@ def synchronize_between_processes(self):
         """
         Warning: does not synchronize the deque!
         """
-        if not is_dist_avail_and_initialized():
-            return
-        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
-        dist.barrier()
-        dist.all_reduce(t)
+        t = reduce_across_processes([self.count, self.total])
         t = t.tolist()
         self.count = int(t[0])
         self.total = t[1]
@@ -400,3 +396,12 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
     os.replace(tmp_path, output_path)
 
     return output_path
+
+
+def reduce_across_processes(val):
+    if not is_dist_avail_and_initialized():
+        return val
+    t = torch.tensor(val, device="cuda")
+    dist.barrier()
+    dist.all_reduce(t)
+    return t
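Finally, a hedged usage sketch of the new helper (assumes an initialized process group with a CUDA-capable backend; local_count is a hypothetical name):

local_count = 128  # e.g. the number of samples processed by this rank
global_count = reduce_across_processes(local_count)  # 0-dim CUDA tensor holding the sum over all ranks

One caveat worth flagging: in the non-distributed case the helper returns val unchanged rather than a tensor, so callers like synchronize_between_processes above, which immediately call .tolist(), implicitly assume a distributed run; returning torch.tensor(val) in the fallback would make both paths uniform.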