Skip to content

Commit 5b81c05

Browse files
authored
Reduce variance of classification references evaluation (#4609)
1 parent 23f413c commit 5b81c05

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

references/classification/train.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
22
import os
33
import time
4+
import warnings
45

56
import presets
67
import torch
@@ -54,6 +55,8 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="
5455
model.eval()
5556
metric_logger = utils.MetricLogger(delimiter=" ")
5657
header = f"Test: {log_suffix}"
58+
59+
num_processed_samples = 0
5760
with torch.no_grad():
5861
for image, target in metric_logger.log_every(data_loader, print_freq, header):
5962
image = image.to(device, non_blocking=True)
@@ -68,7 +71,23 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="
6871
metric_logger.update(loss=loss.item())
6972
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
7073
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
74+
num_processed_samples += batch_size
7175
# gather the stats from all processes
76+
77+
num_processed_samples = utils.reduce_across_processes(num_processed_samples)
78+
if (
79+
hasattr(data_loader.dataset, "__len__")
80+
and len(data_loader.dataset) != num_processed_samples
81+
and torch.distributed.get_rank() == 0
82+
):
83+
# See FIXME above
84+
warnings.warn(
85+
f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
86+
"samples were used for the validation, which might bias the results. "
87+
"Try adjusting the batch size and / or the world size. "
88+
"Setting the world size to 1 is always a safe bet."
89+
)
90+
7291
metric_logger.synchronize_between_processes()
7392

7493
print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
@@ -147,7 +166,7 @@ def load_data(traindir, valdir, args):
147166
print("Creating data loaders")
148167
if args.distributed:
149168
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
150-
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
169+
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
151170
else:
152171
train_sampler = torch.utils.data.RandomSampler(dataset)
153172
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
@@ -164,7 +183,11 @@ def main(args):
164183

165184
device = torch.device(args.device)
166185

167-
torch.backends.cudnn.benchmark = True
186+
if args.use_deterministic_algorithms:
187+
torch.backends.cudnn.benchmark = False
188+
torch.use_deterministic_algorithms(True)
189+
else:
190+
torch.backends.cudnn.benchmark = True
168191

169192
train_dir = os.path.join(args.data_path, "train")
170193
val_dir = os.path.join(args.data_path, "val")
@@ -277,6 +300,10 @@ def main(args):
277300
model_ema.load_state_dict(checkpoint["model_ema"])
278301

279302
if args.test_only:
303+
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
304+
torch.backends.cudnn.benchmark = False
305+
torch.backends.cudnn.deterministic = True
306+
280307
evaluate(model, criterion, data_loader_test, device=device)
281308
return
282309

@@ -394,6 +421,9 @@ def get_args_parser(add_help=True):
394421
default=0.9,
395422
help="decay factor for Exponential Moving Average of model parameters(default: 0.9)",
396423
)
424+
parser.add_argument(
425+
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
426+
)
397427

398428
return parser
399429

references/classification/utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,7 @@ def synchronize_between_processes(self):
3232
"""
3333
Warning: does not synchronize the deque!
3434
"""
35-
if not is_dist_avail_and_initialized():
36-
return
37-
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
38-
dist.barrier()
39-
dist.all_reduce(t)
35+
t = reduce_across_processes([self.count, self.total])
4036
t = t.tolist()
4137
self.count = int(t[0])
4238
self.total = t[1]
@@ -400,3 +396,12 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
400396
os.replace(tmp_path, output_path)
401397

402398
return output_path
399+
400+
401+
def reduce_across_processes(val):
402+
if not is_dist_avail_and_initialized():
403+
return val
404+
t = torch.tensor(val, device="cuda")
405+
dist.barrier()
406+
dist.all_reduce(t)
407+
return t

0 commit comments

Comments
 (0)