@@ -40,16 +40,17 @@ def copypaste_collate_fn(batch):
40
40
return copypaste (* utils .collate_fn (batch ))
41
41
42
42
43
- def get_dataset (name , image_set , transform , data_path , use_v2 ):
44
- paths = {"coco" : (data_path , get_coco , 91 ), "coco_kp" : (data_path , get_coco_kp , 2 )}
45
- p , ds_fn , num_classes = paths [name ]
43
+ def get_dataset (is_train , args ):
44
+ image_set = "train" if is_train else "val"
45
+ paths = {"coco" : (args .data_path , get_coco , 91 ), "coco_kp" : (args .data_path , get_coco_kp , 2 )}
46
+ p , ds_fn , num_classes = paths [args .dataset ]
46
47
47
- ds = ds_fn (p , image_set = image_set , transforms = transform , use_v2 = use_v2 )
48
+ ds = ds_fn (p , image_set = image_set , transforms = get_transform ( is_train , args ), use_v2 = args . use_v2 )
48
49
return ds , num_classes
49
50
50
51
51
- def get_transform (train , args ):
52
- if train :
52
+ def get_transform (is_train , args ):
53
+ if is_train :
53
54
return presets .DetectionPresetTrain (
54
55
data_augmentation = args .data_augmentation , backend = args .backend , use_v2 = args .use_v2
55
56
)
@@ -185,8 +186,8 @@ def main(args):
185
186
# Data loading code
186
187
print ("Loading data" )
187
188
188
- dataset , num_classes = get_dataset (args . dataset , "train" , get_transform ( True , args ), args . data_path , args . use_v2 )
189
- dataset_test , _ = get_dataset (args . dataset , "val" , get_transform ( False , args ), args . data_path , args . use_v2 )
189
+ dataset , num_classes = get_dataset (is_train = True , args = args )
190
+ dataset_test , _ = get_dataset (is_train = False , args = args )
190
191
191
192
print ("Creating data loaders" )
192
193
if args .distributed :
0 commit comments