@@ -1,6 +1,7 @@
 import datetime
 import os
 import time
+import warnings

 import presets
 import torch
@@ -54,6 +55,8 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = f"Test: {log_suffix}"
+
+    num_processed_samples = 0
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
@@ -68,7 +71,23 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
             metric_logger.update(loss=loss.item())
             metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
             metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+            num_processed_samples += batch_size
     # gather the stats from all processes
+
+    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+    if (
+        hasattr(data_loader.dataset, "__len__")
+        and len(data_loader.dataset) != num_processed_samples
+        and torch.distributed.get_rank() == 0
+    ):
+        # See FIXME above
+        warnings.warn(
+            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
+            "samples were used for the validation, which might bias the results. "
+            "Try adjusting the batch size and / or the world size. "
+            "Setting the world size to 1 is always a safe bet."
+        )
+
     metric_logger.synchronize_between_processes()

     print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
@@ -147,7 +166,7 @@ def load_data(traindir, valdir, args):
     print("Creating data loaders")
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
-        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
+        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
     else:
         train_sampler = torch.utils.data.RandomSampler(dataset)
         test_sampler = torch.utils.data.SequentialSampler(dataset_test)
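Passing `shuffle=False` keeps the validation order fixed, but `DistributedSampler` still pads the dataset so every rank gets an equal share, which is exactly what the new warning in `evaluate` guards against. A rough illustration of the arithmetic, with made-up numbers:

    import math

    # DistributedSampler with drop_last=False hands each rank
    # ceil(len(dataset) / world_size) samples, duplicating a few to pad.
    world_size = 8
    dataset_len = 49_999  # hypothetical validation set size

    per_rank = math.ceil(dataset_len / world_size)  # 6250
    total_processed = per_rank * world_size         # 50_000 != 49_999
    # The duplicated samples are counted twice, slightly biasing Acc@1/Acc@5.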
@@ -164,7 +183,11 @@ def main(args):

     device = torch.device(args.device)

-    torch.backends.cudnn.benchmark = True
+    if args.use_deterministic_algorithms:
+        torch.backends.cudnn.benchmark = False
+        torch.use_deterministic_algorithms(True)
+    else:
+        torch.backends.cudnn.benchmark = True

     train_dir = os.path.join(args.data_path, "train")
     val_dir = os.path.join(args.data_path, "val")
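On the new branch above: `torch.use_deterministic_algorithms(True)` makes PyTorch raise a `RuntimeError` whenever an op that has no deterministic implementation is executed. A hedged sketch of what a fully deterministic run additionally needs on CUDA (the seed and the env var placement are illustrative):

    import os
    import torch

    # Some cuBLAS ops on CUDA >= 10.2 are only deterministic if this is set
    # before the first CUDA call; in practice it is exported in the shell.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    torch.manual_seed(0)                      # illustrative seed
    torch.backends.cudnn.benchmark = False    # no shape-based autotuning
    torch.use_deterministic_algorithms(True)  # error on non-deterministic ops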
@@ -277,6 +300,10 @@ def main(args):
             model_ema.load_state_dict(checkpoint["model_ema"])

     if args.test_only:
+        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+
         evaluate(model, criterion, data_loader_test, device=device)
         return
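The two cuDNN flags disabled here are related but distinct knobs from the global switch introduced earlier: with `benchmark = True`, cuDNN autotunes the convolution algorithm per input shape, and different algorithms can give slightly different floating-point results, which is enough to move the reported Acc@1/Acc@5 between runs. A short sketch of the distinction:

    import torch

    torch.backends.cudnn.benchmark = False     # cuDNN only: disable autotuning
    torch.backends.cudnn.deterministic = True  # cuDNN only: pick deterministic kernels
    # torch.use_deterministic_algorithms(True) goes further: it covers all ops,
    # not just cuDNN, and raises when only a non-deterministic kernel exists.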
@@ -394,6 +421,9 @@ def get_args_parser(add_help=True):
         default=0.9,
         help="decay factor for Exponential Moving Average of model parameters(default: 0.9)",
     )
+    parser.add_argument(
+        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
+    )

     return parser
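For reference, a hypothetical invocation exercising the new flag (data path and model name are placeholders):

    python train.py --data-path /path/to/imagenet --model resnet50 --test-only --use-deterministic-algorithms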