diff --git a/references/classification/train.py b/references/classification/train.py
index a71d337a1b4..9b1994bad57 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -1,6 +1,7 @@
 import datetime
 import os
 import time
+import warnings

 import presets
 import torch
@@ -54,6 +55,8 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = f"Test: {log_suffix}"
+
+    num_processed_samples = 0
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
@@ -68,7 +71,23 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
             metric_logger.update(loss=loss.item())
             metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
             metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+            num_processed_samples += batch_size
     # gather the stats from all processes
+
+    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+    if (
+        hasattr(data_loader.dataset, "__len__")
+        and len(data_loader.dataset) != num_processed_samples
+        and torch.distributed.get_rank() == 0
+    ):
+        # See FIXME above
+        warnings.warn(
+            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
+            "samples were used for the validation, which might bias the results. "
+            "Try adjusting the batch size and / or the world size. "
+            "Setting the world size to 1 is always a safe bet."
+        )
+
     metric_logger.synchronize_between_processes()

     print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
@@ -147,7 +166,7 @@ def load_data(traindir, valdir, args):
     print("Creating data loaders")
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
-        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
+        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
     else:
         train_sampler = torch.utils.data.RandomSampler(dataset)
         test_sampler = torch.utils.data.SequentialSampler(dataset_test)
@@ -164,7 +183,11 @@ def main(args):
     device = torch.device(args.device)

-    torch.backends.cudnn.benchmark = True
+    if args.use_deterministic_algorithms:
+        torch.backends.cudnn.benchmark = False
+        torch.use_deterministic_algorithms(True)
+    else:
+        torch.backends.cudnn.benchmark = True

     train_dir = os.path.join(args.data_path, "train")
     val_dir = os.path.join(args.data_path, "val")
     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
@@ -277,6 +300,10 @@
             model_ema.load_state_dict(checkpoint["model_ema"])

     if args.test_only:
+        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+
         evaluate(model, criterion, data_loader_test, device=device)
         return

@@ -394,6 +421,9 @@ def get_args_parser(add_help=True):
         default=0.9,
         help="decay factor for Exponential Moving Average of model parameters(default: 0.9)",
     )
+    parser.add_argument(
+        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
+    )

     return parser

diff --git a/references/classification/utils.py b/references/classification/utils.py
index 5dbb6b8fd24..c186a60fc1e 100644
--- a/references/classification/utils.py
+++ b/references/classification/utils.py
@@ -32,11 +32,7 @@ def synchronize_between_processes(self):
         """
         Warning: does not synchronize the deque!
         """
-        if not is_dist_avail_and_initialized():
-            return
-        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
-        dist.barrier()
-        dist.all_reduce(t)
+        t = reduce_across_processes([self.count, self.total])
         t = t.tolist()
         self.count = int(t[0])
         self.total = t[1]
@@ -400,3 +396,12 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
     os.replace(tmp_path, output_path)

     return output_path
+
+
+def reduce_across_processes(val):
+    if not is_dist_avail_and_initialized():
+        return val
+    t = torch.tensor(val, device="cuda")
+    dist.barrier()
+    dist.all_reduce(t)
+    return t