Skip to content

Commit 56dc431

Browse files
committed
clean up parameter passing
1 parent 72da655 commit 56dc431

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

references/detection/presets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ def get_modules(use_v2):
1616

1717

1818
class DetectionPresetTrain:
19+
# Note: this transform assumes that the input to forward() are always PIL
20+
# images, regardless of the backend parameter.
1921
def __init__(
2022
self,
2123
*,

references/detection/train.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,17 @@ def copypaste_collate_fn(batch):
4040
return copypaste(*utils.collate_fn(batch))
4141

4242

43-
def get_dataset(name, image_set, transform, data_path, use_v2):
44-
paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
45-
p, ds_fn, num_classes = paths[name]
43+
def get_dataset(is_train, args):
44+
image_set = "train" if is_train else "val"
45+
paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)}
46+
p, ds_fn, num_classes = paths[args.dataset]
4647

47-
ds = ds_fn(p, image_set=image_set, transforms=transform, use_v2=use_v2)
48+
ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
4849
return ds, num_classes
4950

5051

51-
def get_transform(train, args):
52-
if train:
52+
def get_transform(is_train, args):
53+
if is_train:
5354
return presets.DetectionPresetTrain(
5455
data_augmentation=args.data_augmentation, backend=args.backend, use_v2=args.use_v2
5556
)
@@ -185,8 +186,8 @@ def main(args):
185186
# Data loading code
186187
print("Loading data")
187188

188-
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path, args.use_v2)
189-
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path, args.use_v2)
189+
dataset, num_classes = get_dataset(is_train=True, args=args)
190+
dataset_test, _ = get_dataset(is_train=False, args=args)
190191

191192
print("Creating data loaders")
192193
if args.distributed:

0 commit comments

Comments
 (0)