@@ -116,8 +116,8 @@ def main(args):
116
116
117
117
# Data loading code
118
118
print ("Loading data" )
119
- traindir = os .path .join (args .data_path , 'train_avi-480p' )
120
- valdir = os .path .join (args .data_path , 'val_avi-480p' )
119
+ traindir = os .path .join (args .data_path , args . train_dir )
120
+ valdir = os .path .join (args .data_path , args . val_dir )
121
121
normalize = T .Normalize (mean = [0.43216 , 0.394666 , 0.37645 ],
122
122
std = [0.22803 , 0.22145 , 0.216989 ])
123
123
@@ -274,6 +274,8 @@ def parse_args():
274
274
parser = argparse .ArgumentParser (description = 'PyTorch Classification Training' )
275
275
276
276
parser .add_argument ('--data-path' , default = '/datasets01_101/kinetics/070618/' , help = 'dataset' )
277
+ parser .add_argument ('--train-dir' , default = 'train_avi-480p' , help = 'name of train dir' )
278
+ parser .add_argument ('--val-dir' , default = 'val_avi-480p' , help = 'name of val dir' )
277
279
parser .add_argument ('--model' , default = 'r2plus1d_18' , help = 'model' )
278
280
parser .add_argument ('--device' , default = 'cuda' , help = 'device' )
279
281
parser .add_argument ('--clip-len' , default = 16 , type = int , metavar = 'N' ,
0 commit comments