16
16
17
17
18
18
parser = argparse .ArgumentParser (description = 'PyTorch ImageNet Training' )
19
- parser .add_argument ('-- data' , metavar = 'PATH' , required = True ,
19
+ parser .add_argument ('data' , metavar = 'DIR' ,
20
20
help = 'path to dataset' )
21
21
parser .add_argument ('--arch' , '-a' , metavar = 'ARCH' , default = 'resnet18' ,
22
- help = 'model architecture: resnet18 | resnet34 | ...'
22
+ help = 'model architecture: resnet18 | resnet34 | ... '
23
23
'(default: resnet18)' )
24
24
parser .add_argument ('-j' , '--workers' , default = 4 , type = int , metavar = 'N' ,
25
25
help = 'number of data loading workers (default: 4)' )
39
39
metavar = 'N' , help = 'print frequency (default: 10)' )
40
40
parser .add_argument ('--resume' , default = '' , type = str , metavar = 'PATH' ,
41
41
help = 'path to latest checkpoint (default: none)' )
42
+ parser .add_argument ('-e' , '--evaluate' , type = str , metavar = 'FILE' ,
43
+ help = 'evaluate model FILE on validation set' )
42
44
43
45
best_prec1 = 0
44
46
@@ -50,20 +52,28 @@ def main():
50
52
# create model
51
53
if args .arch .startswith ('resnet' ):
52
54
print ("=> creating model '{}'" .format (args .arch ))
53
- model = resnet .__dict__ [args .arch ]()
55
+ model = torch . nn . DataParallel ( resnet .__dict__ [args .arch ]() )
54
56
model .cuda ()
55
57
else :
56
58
parser .error ('invalid architecture: {}' .format (args .arch ))
57
59
58
60
# optionally resume from a checkpoint
59
- if args .resume :
61
+ if args .evaluate :
62
+ if not os .path .isfile (args .evaluate ):
63
+ parser .error ('invalid checkpoint: {}' .format (args .evaluate ))
64
+ checkpoint = torch .load (args .evaluate )
65
+ model .load_state_dict (checkpoint ['state_dict' ])
66
+ print ("=> loaded checkpoint '{}' (epoch {})"
67
+ .format (args .evaluate , checkpoint ['epoch' ]))
68
+ elif args .resume :
60
69
if os .path .isfile (args .resume ):
61
70
print ("=> loading checkpoint '{}'" .format (args .resume ))
62
71
checkpoint = torch .load (args .resume )
63
72
args .start_epoch = checkpoint ['epoch' ]
64
73
best_prec1 = checkpoint ['best_prec1' ]
65
74
model .load_state_dict (checkpoint ['state_dict' ])
66
- print (" | resuming from epoch {}" .format (args .start_epoch ))
75
+ print ("=> loaded checkpoint '{}' (epoch {})"
76
+ .format (args .evaluate , checkpoint ['epoch' ]))
67
77
else :
68
78
print ("=> no checkpoint found at '{}'" .format (args .resume ))
69
79
@@ -95,32 +105,31 @@ def main():
95
105
batch_size = args .batch_size , shuffle = False ,
96
106
num_workers = args .workers , pin_memory = True )
97
107
98
- # parallelize model across all visible GPUs
99
- model = torch .nn .DataParallel (model )
100
-
101
108
# define loss function (criterion) and pptimizer
102
109
criterion = nn .CrossEntropyLoss ().cuda ()
103
110
104
111
optimizer = torch .optim .SGD (model .parameters (), args .lr ,
105
112
momentum = args .momentum ,
106
113
weight_decay = args .weight_decay )
107
114
115
+ if args .evaluate :
116
+ validate (val_loader , model , criterion )
117
+ return
118
+
108
119
for epoch in range (args .start_epoch , args .epochs ):
109
120
adjust_learning_rate (optimizer , epoch )
110
121
111
122
# train for one epoch
112
- model .train ()
113
123
train (train_loader , model , criterion , optimizer , epoch )
114
124
115
125
# evaluate on validation set
116
- model .eval ()
117
126
prec1 = validate (val_loader , model , criterion )
118
127
119
128
# remember best prec@1 and save checkpoint
120
129
is_best = prec1 > best_prec1
121
130
best_prec1 = max (prec1 , best_prec1 )
122
131
save_checkpoint ({
123
- 'epoch' : epoch ,
132
+ 'epoch' : epoch + 1 ,
124
133
'arch' : args .arch ,
125
134
'state_dict' : model .state_dict (),
126
135
'best_prec1' : best_prec1 ,
@@ -134,6 +143,9 @@ def train(train_loader, model, criterion, optimizer, epoch):
134
143
top1 = AverageMeter ()
135
144
top5 = AverageMeter ()
136
145
146
+ # switch to train mode
147
+ model .train ()
148
+
137
149
end = time .time ()
138
150
for i , (input , target ) in enumerate (train_loader ):
139
151
# measure data loading time
@@ -149,9 +161,9 @@ def train(train_loader, model, criterion, optimizer, epoch):
149
161
150
162
# measure accuracy and record loss
151
163
prec1 , prec5 = accuracy (output .data , target , topk = (1 , 5 ))
152
- losses .update (loss .data [0 ])
153
- top1 .update (prec1 [0 ])
154
- top5 .update (prec5 [0 ])
164
+ losses .update (loss .data [0 ], input . size ( 0 ) )
165
+ top1 .update (prec1 [0 ], input . size ( 0 ) )
166
+ top5 .update (prec5 [0 ], input . size ( 0 ) )
155
167
156
168
# compute gradient and do SGD step
157
169
optimizer .zero_grad ()
@@ -179,6 +191,9 @@ def validate(val_loader, model, criterion):
179
191
top1 = AverageMeter ()
180
192
top5 = AverageMeter ()
181
193
194
+ # switch to evaluate mode
195
+ model .eval ()
196
+
182
197
end = time .time ()
183
198
for i , (input , target ) in enumerate (val_loader ):
184
199
target = target .cuda (async = True )
@@ -191,9 +206,9 @@ def validate(val_loader, model, criterion):
191
206
192
207
# measure accuracy and record loss
193
208
prec1 , prec5 = accuracy (output .data , target , topk = (1 , 5 ))
194
- losses .update (loss .data [0 ])
195
- top1 .update (prec1 [0 ])
196
- top5 .update (prec5 [0 ])
209
+ losses .update (loss .data [0 ], input . size ( 0 ) )
210
+ top1 .update (prec1 [0 ], input . size ( 0 ) )
211
+ top5 .update (prec5 [0 ], input . size ( 0 ) )
197
212
198
213
# measure elapsed time
199
214
batch_time .update (time .time () - end )
@@ -208,6 +223,9 @@ def validate(val_loader, model, criterion):
208
223
i , len (val_loader ), batch_time = batch_time , loss = losses ,
209
224
top1 = top1 , top5 = top5 ))
210
225
226
+ print (' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
227
+ .format (top1 = top1 , top5 = top5 ))
228
+
211
229
return top1 .avg
212
230
213
231
@@ -226,13 +244,13 @@ def reset(self):
226
244
self .val = 0
227
245
self .avg = 0
228
246
self .sum = 0
229
- self .n = 0
247
+ self .count = 0
230
248
231
- def update (self , val ):
249
+ def update (self , val , n = 1 ):
232
250
self .val = val
233
- self .sum += val
234
- self .n += 1
235
- self .avg = self .sum / self .n
251
+ self .sum += val * n
252
+ self .count += n
253
+ self .avg = self .sum / self .count
236
254
237
255
238
256
def adjust_learning_rate (optimizer , epoch ):
@@ -247,7 +265,7 @@ def accuracy(output, target, topk=(1,)):
247
265
maxk = max (topk )
248
266
batch_size = target .size (0 )
249
267
250
- _ , pred = output .topk (maxk , True , True )
268
+ _ , pred = output .topk (maxk , 1 , True , True )
251
269
pred = pred .t ()
252
270
correct = pred .eq (target .view (1 , - 1 ).expand_as (pred ))
253
271
0 commit comments