
Commit 1703e4c

Adding Preset Transforms in reference scripts (#3317)

* Adding presets in the classification reference scripts.
* Adding presets in the object detection reference scripts.
* Adding presets in the segmentation reference scripts.
* Adding presets in the video classification reference scripts.
* Moving flip to the end to align with the image classification signature.

1 parent 7621a8e commit 1703e4c

9 files changed: +143 −77 lines changed

references/classification/README.md

+1 −14

````diff
@@ -124,22 +124,9 @@ Training converges at about 10 epochs.
 For post training quant, device is set to CPU. For training, the device is set to CUDA
 
 ### Command to evaluate quantized models using the pre-trained weights:
-For all quantized models except inception_v3:
+For all quantized models:
 ```
 python references/classification/train_quantization.py --data-path='imagenet_full_size/' \
     --device='cpu' --test-only --backend='fbgemm' --model='<model_name>'
 ```
 
-For inception_v3, since it expects tensors with a size of N x 3 x 299 x 299, before running above command,
-need to change the input size of dataset_test in train.py to:
-```
-dataset_test = torchvision.datasets.ImageFolder(
-    valdir,
-    transforms.Compose([
-        transforms.Resize(342),
-        transforms.CenterCrop(299),
-        transforms.ToTensor(),
-        normalize,
-    ]))
-```
-
````
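Note: the inception_v3-specific instructions are dropped because train.py (below) now selects the evaluation sizes automatically, using (resize_size, crop_size) = (342, 299) when --model='inception_v3' and (256, 224) otherwise.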
references/classification/presets.py

+37 (new file)

```diff
@@ -0,0 +1,37 @@
+from torchvision.transforms import autoaugment, transforms
+
+
+class ClassificationPresetTrain:
+    def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), hflip_prob=0.5,
+                 auto_augment_policy=None, random_erase_prob=0.0):
+        trans = [transforms.RandomResizedCrop(crop_size)]
+        if hflip_prob > 0:
+            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+        if auto_augment_policy is not None:
+            aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
+            trans.append(autoaugment.AutoAugment(policy=aa_policy))
+        trans.extend([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+        if random_erase_prob > 0:
+            trans.append(transforms.RandomErasing(p=random_erase_prob))
+
+        self.transforms = transforms.Compose(trans)
+
+    def __call__(self, img):
+        return self.transforms(img)
+
+
+class ClassificationPresetEval:
+    def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+
+        self.transforms = transforms.Compose([
+            transforms.Resize(resize_size),
+            transforms.CenterCrop(crop_size),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+
+    def __call__(self, img):
+        return self.transforms(img)
```

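For context, a minimal usage sketch of the new presets (my illustration, not part of the commit; example.jpg is a placeholder path):

```python
# Illustrative only: apply the new classification presets to a single PIL image.
from PIL import Image

from presets import ClassificationPresetTrain, ClassificationPresetEval

img = Image.open("example.jpg").convert("RGB")  # placeholder image path

train_tf = ClassificationPresetTrain(crop_size=224, auto_augment_policy="imagenet",
                                     random_erase_prob=0.1)
eval_tf = ClassificationPresetEval(crop_size=224, resize_size=256)

print(train_tf(img).shape)  # torch.Size([3, 224, 224])
print(eval_tf(img).shape)   # torch.Size([3, 224, 224])
```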
references/classification/train.py

+5 −23

```diff
@@ -6,8 +6,8 @@
 import torch.utils.data
 from torch import nn
 import torchvision
-from torchvision import transforms
 
+import presets
 import utils
 
 try:
@@ -82,8 +82,7 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
+    resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)
 
     print("Loading training data")
     st = time.time()
@@ -93,22 +92,10 @@ def load_data(traindir, valdir, args):
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
     else:
-        trans = [
-            transforms.RandomResizedCrop(224),
-            transforms.RandomHorizontalFlip(),
-        ]
-        if args.auto_augment is not None:
-            aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
-            trans.append(transforms.AutoAugment(policy=aa_policy))
-        trans.extend([
-            transforms.ToTensor(),
-            normalize,
-        ])
-        if args.random_erase > 0:
-            trans.append(transforms.RandomErasing(p=args.random_erase))
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            transforms.Compose(trans))
+            presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=args.auto_augment,
+                                              random_erase_prob=args.random_erase))
         if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
@@ -124,12 +111,7 @@ def load_data(traindir, valdir, args):
     else:
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size))
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
```
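A quick sanity check, assuming presets.py is on the path (my sketch, not from the commit), that the eval preset reproduces the inline pipeline it replaces for the default sizes:

```python
# The deterministic eval preset should match the old inline Compose bit-for-bit.
import torch
from PIL import Image
from torchvision import transforms

from presets import ClassificationPresetEval

img = Image.new("RGB", (500, 375), color=(128, 64, 32))  # synthetic test image

old_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
new_tf = ClassificationPresetEval(crop_size=224, resize_size=256)

assert torch.equal(old_tf(img), new_tf(img))
```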

references/detection/presets.py

+21 (new file)

```diff
@@ -0,0 +1,21 @@
+import transforms as T
+
+
+class DetectionPresetTrain:
+    def __init__(self, hflip_prob=0.5):
+        trans = [T.ToTensor()]
+        if hflip_prob > 0:
+            trans.append(T.RandomHorizontalFlip(hflip_prob))
+
+        self.transforms = T.Compose(trans)
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)
+
+
+class DetectionPresetEval:
+    def __init__(self):
+        self.transforms = T.ToTensor()
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)
```
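A hypothetical usage sketch: these presets follow the (image, target) convention of the reference transforms in references/detection/transforms.py, so the image and its boxes go through the pipeline together:

```python
# Illustrative only: detection presets transform (image, target) pairs jointly.
import torch
from PIL import Image

from presets import DetectionPresetTrain, DetectionPresetEval

img = Image.new("RGB", (640, 480))
target = {"boxes": torch.tensor([[10.0, 20.0, 100.0, 200.0]]),
          "labels": torch.tensor([1])}

train_tf = DetectionPresetTrain(hflip_prob=0.5)
img_t, target_t = train_tf(img, target)  # image -> tensor; boxes flipped with p=0.5

eval_tf = DetectionPresetEval()
img_e, target_e = eval_tf(img, target)   # eval only converts the image to a tensor
```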

references/detection/train.py

+2 −6

```diff
@@ -32,8 +32,8 @@
 from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
 from engine import train_one_epoch, evaluate
 
+import presets
 import utils
-import transforms as T
 
 
 def get_dataset(name, image_set, transform, data_path):
@@ -48,11 +48,7 @@ def get_dataset(name, image_set, transform, data_path):
 
 
 def get_transform(train):
-    transforms = []
-    transforms.append(T.ToTensor())
-    if train:
-        transforms.append(T.RandomHorizontalFlip(0.5))
-    return T.Compose(transforms)
+    return presets.DetectionPresetTrain() if train else presets.DetectionPresetEval()
 
 
 def main(args):
```

references/segmentation/presets.py

+32 (new file)

```diff
@@ -0,0 +1,32 @@
+import transforms as T
+
+
+class SegmentationPresetTrain:
+    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+        min_size = int(0.5 * base_size)
+        max_size = int(2.0 * base_size)
+
+        trans = [T.RandomResize(min_size, max_size)]
+        if hflip_prob > 0:
+            trans.append(T.RandomHorizontalFlip(hflip_prob))
+        trans.extend([
+            T.RandomCrop(crop_size),
+            T.ToTensor(),
+            T.Normalize(mean=mean, std=std),
+        ])
+        self.transforms = T.Compose(trans)
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)
+
+
+class SegmentationPresetEval:
+    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+        self.transforms = T.Compose([
+            T.RandomResize(base_size, base_size),
+            T.ToTensor(),
+            T.Normalize(mean=mean, std=std),
+        ])
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)
```
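A hypothetical usage sketch (assumes the segmentation reference transforms.py, where every transform takes an (image, mask) pair):

```python
# Illustrative only: segmentation presets transform image and mask together.
from PIL import Image

from presets import SegmentationPresetTrain, SegmentationPresetEval

img = Image.new("RGB", (640, 480))
mask = Image.new("L", (640, 480))  # per-pixel class indices

train_tf = SegmentationPresetTrain(base_size=520, crop_size=480)
img_t, mask_t = train_tf(img, mask)  # resize into [260, 1040], maybe flip, 480x480 crop

eval_tf = SegmentationPresetEval(base_size=520)
img_e, mask_e = eval_tf(img, mask)   # deterministic resize of the shorter side to 520
```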

references/segmentation/train.py

+2 −13

```diff
@@ -8,7 +8,7 @@
 import torchvision
 
 from coco_utils import get_coco
-import transforms as T
+import presets
 import utils
 
 
@@ -30,18 +30,7 @@ def get_transform(train):
     base_size = 520
     crop_size = 480
 
-    min_size = int((0.5 if train else 1.0) * base_size)
-    max_size = int((2.0 if train else 1.0) * base_size)
-    transforms = []
-    transforms.append(T.RandomResize(min_size, max_size))
-    if train:
-        transforms.append(T.RandomHorizontalFlip(0.5))
-        transforms.append(T.RandomCrop(crop_size))
-    transforms.append(T.ToTensor())
-    transforms.append(T.Normalize(mean=[0.485, 0.456, 0.406],
-                                  std=[0.229, 0.224, 0.225]))
-
-    return T.Compose(transforms)
+    return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size)
 
 
 def criterion(inputs, target):
```
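Equivalence note: for eval, the old code collapsed the resize range to min_size = max_size = base_size, which SegmentationPresetEval encodes as T.RandomResize(base_size, base_size); evaluation preprocessing is therefore unchanged by the refactor.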
references/video_classification/presets.py

+40 (new file)

```diff
@@ -0,0 +1,40 @@
+import torch
+
+from torchvision.transforms import transforms
+from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW
+
+
+class VideoClassificationPresetTrain:
+    def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989),
+                 hflip_prob=0.5):
+        trans = [
+            ConvertBHWCtoBCHW(),
+            transforms.ConvertImageDtype(torch.float32),
+            transforms.Resize(resize_size),
+        ]
+        if hflip_prob > 0:
+            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+        trans.extend([
+            transforms.Normalize(mean=mean, std=std),
+            transforms.RandomCrop(crop_size),
+            ConvertBCHWtoCBHW()
+        ])
+        self.transforms = transforms.Compose(trans)
+
+    def __call__(self, x):
+        return self.transforms(x)
+
+
+class VideoClassificationPresetEval:
+    def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)):
+        self.transforms = transforms.Compose([
+            ConvertBHWCtoBCHW(),
+            transforms.ConvertImageDtype(torch.float32),
+            transforms.Resize(resize_size),
+            transforms.Normalize(mean=mean, std=std),
+            transforms.CenterCrop(crop_size),
+            ConvertBCHWtoCBHW()
+        ])
+
+    def __call__(self, x):
+        return self.transforms(x)
```
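A hypothetical usage sketch: the video presets take a uint8 clip shaped (T, H, W, C), as produced by torchvision's video datasets, and return a float (C, T, H, W) tensor:

```python
# Illustrative only: run a synthetic 16-frame clip through the video presets.
import torch

from presets import VideoClassificationPresetTrain, VideoClassificationPresetEval

clip = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)  # (T, H, W, C)

train_tf = VideoClassificationPresetTrain((128, 171), (112, 112))
eval_tf = VideoClassificationPresetEval((128, 171), (112, 112))

print(train_tf(clip).shape)  # torch.Size([3, 16, 112, 112])
print(eval_tf(clip).shape)   # torch.Size([3, 16, 112, 112])
```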

references/video_classification/train.py

+3 −21

```diff
@@ -7,13 +7,12 @@
 from torch import nn
 import torchvision
 import torchvision.datasets.video_utils
-from torchvision import transforms as T
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
+import presets
 import utils
 
 from scheduler import WarmupMultiStepLR
-from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW
 
 try:
     from apex import amp
@@ -112,21 +111,11 @@ def main(args):
     print("Loading data")
     traindir = os.path.join(args.data_path, args.train_dir)
     valdir = os.path.join(args.data_path, args.val_dir)
-    normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
-                            std=[0.22803, 0.22145, 0.216989])
 
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    transform_train = torchvision.transforms.Compose([
-        ConvertBHWCtoBCHW(),
-        T.ConvertImageDtype(torch.float32),
-        T.Resize((128, 171)),
-        T.RandomHorizontalFlip(),
-        normalize,
-        T.RandomCrop((112, 112)),
-        ConvertBCHWtoCBHW()
-    ])
+    transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112))
 
     if args.cache_dataset and os.path.exists(cache_path):
         print("Loading dataset_train from {}".format(cache_path))
@@ -154,14 +143,7 @@ def main(args):
     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
 
-    transform_test = torchvision.transforms.Compose([
-        ConvertBHWCtoBCHW(),
-        T.ConvertImageDtype(torch.float32),
-        T.Resize((128, 171)),
-        normalize,
-        T.CenterCrop((112, 112)),
-        ConvertBCHWtoCBHW()
-    ])
+    transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
 
     if args.cache_dataset and os.path.exists(cache_path):
         print("Loading dataset_test from {}".format(cache_path))
```
