diff --git a/references/classification/train.py b/references/classification/train.py index 0e1a3878bd5..a3e4c9ad8e9 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -71,8 +71,7 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=' # gather the stats from all processes metric_logger.synchronize_between_processes() - print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}' - .format(top1=metric_logger.acc1, top5=metric_logger.acc5)) + print(f'{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}') return metric_logger.acc1.global_avg