diff --git a/coco.py b/coco.py
index f9dc44a..7f6efb7 100644
--- a/coco.py
+++ b/coco.py
@@ -50,6 +50,7 @@
 import model as modellib
 import torch
+import imgaug.augmenters as iaa
 
 # Root directory of the project
 ROOT_DIR = os.getcwd()
@@ -474,6 +475,40 @@ class InferenceConfig(CocoConfig):
     print("Loading weights ", model_path)
     model.load_weights(model_path)
 
+
+    augmentation = iaa.Sometimes(0.667, iaa.Sequential([
+        iaa.Fliplr(0.5),  # horizontal flips
+        iaa.Crop(percent=(0, 0.1)),  # random crops
+        # Small gaussian blur with random sigma between 0 and 0.25,
+        # but only for about 50% of all images.
+        iaa.Sometimes(0.5,
+                      iaa.GaussianBlur(sigma=(0, 0.25))),
+        # Strengthen or weaken the contrast in each image.
+        iaa.ContrastNormalization((0.75, 1.5)),
+        # Add gaussian noise. Without per_channel, the noise is sampled
+        # once per pixel and shared across channels, so it changes the
+        # brightness but not the color of the pixels. Pass
+        # per_channel=0.5 to sample the noise per pixel AND channel for
+        # half of the images, which can also change the color.
+        iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255)),
+        # Make some images brighter and some darker. Pass
+        # per_channel=0.2 to sample the multiplier once per channel in
+        # 20% of all cases, which can change the color of the images.
+        iaa.Multiply((0.8, 1.2)),
+        # Apply affine transformations to each image: scale/zoom and
+        # rotate them. Translation and shear are left disabled here.
+        iaa.Affine(
+            scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
+            # translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
+            rotate=(-180, 180),
+            # shear=(-8, 8)
+        )
+    ], random_order=True))  # apply the augmenters in random order
+
+
     # Train or evaluate
     if args.command == "train":
         # Training dataset. Use the training set and 35K from the
@@ -495,7 +530,8 @@
         model.train_model(dataset_train, dataset_val,
                           learning_rate=config.LEARNING_RATE,
                           epochs=40,
-                          layers='heads')
+                          layers='heads',
+                          augmentation=augmentation)
 
         # Training - Stage 2
         # Finetune layers from ResNet stage 4 and up
@@ -503,7 +539,8 @@
         model.train_model(dataset_train, dataset_val,
                           learning_rate=config.LEARNING_RATE,
                           epochs=120,
-                          layers='4+')
+                          layers='4+',
+                          augmentation=augmentation)
 
         # Training - Stage 3
         # Fine tune all layers
@@ -511,7 +548,8 @@
         model.train_model(dataset_train, dataset_val,
                           learning_rate=config.LEARNING_RATE / 10,
                           epochs=160,
-                          layers='all')
+                          layers='all',
+                          augmentation=augmentation)
 
     elif args.command == "evaluate":
         # Validation dataset
diff --git a/model.py b/model.py
index 302e111..95ba74f 100644
--- a/model.py
+++ b/model.py
@@ -1137,7 +1137,7 @@ def compute_losses(rpn_match, rpn_bbox, rpn_class_logits, rpn_pred_bbox, target_
 ############################################################
 
 def load_image_gt(dataset, config, image_id, augment=False,
-                  use_mini_mask=False):
+                  use_mini_mask=False, augmentation=None):
     """Load and return ground truth data for an image (image, mask, bounding boxes).
 
     augment: If true, apply random image augmentation. Currently, only
@@ -1174,6 +1174,34 @@ def load_image_gt(dataset, config, image_id, augment=False,
         image = np.fliplr(image)
         mask = np.fliplr(mask)
 
+    if augmentation:
+        import imgaug
+
+        # Augmenters that are safe to apply to masks.
+        # Some, such as Affine, have settings that make them unsafe, so always
+        # test your augmentation on masks.
+        MASK_AUGMENTERS = ["Sequential", "SomeOf", "OneOf", "Sometimes",
+                           "Fliplr", "Flipud", "CropAndPad",
+                           "Affine", "PiecewiseAffine"]
+
+        def hook(images, augmenter, parents, default):
+            """Determines which augmenters to apply to masks."""
+            return augmenter.__class__.__name__ in MASK_AUGMENTERS
+
+        # Store shapes before augmentation to compare
+        image_shape = image.shape
+        mask_shape = mask.shape
+        # Make augmenters deterministic to apply the same transform to
+        # the image and its masks
+        det = augmentation.to_deterministic()
+        image = det.augment_image(image)
+        mask = det.augment_image(mask.astype(np.uint8),
+                                 hooks=imgaug.HooksImages(activator=hook))
+        # Verify that shapes didn't change
+        assert image.shape == image_shape, "Augmentation shouldn't change image size"
+        assert mask.shape == mask_shape, "Augmentation shouldn't change mask size"
+        # Change mask back to bool (np.bool was removed in NumPy 1.24)
+        mask = mask.astype(bool)
+
     # Bounding boxes. Note that some boxes might be all zeros
     # if the corresponding mask got cropped out.
     # bbox: [num_instances, (y1, x1, y2, x2)]
@@ -1306,7 +1334,7 @@
     return rpn_match, rpn_bbox
 
 
 class Dataset(torch.utils.data.Dataset):
-    def __init__(self, dataset, config, augment=True):
+    def __init__(self, dataset, config, augment=True, augmentation=None):
         """A generator that returns images and corresponding target class ids,
         bounding box deltas, and masks.
@@ -1342,6 +1370,7 @@ def __init__(self, dataset, config, augment=True):
         self.dataset = dataset
         self.config = config
         self.augment = augment
+        self.augmentation = augmentation
 
         # Anchors
         # [anchor_count, (y1, x1, y2, x2)]
@@ -1356,7 +1385,7 @@ def __getitem__(self, image_index):
         image_id = self.image_ids[image_index]
         image, image_metas, gt_class_ids, gt_boxes, gt_masks = \
             load_image_gt(self.dataset, self.config, image_id, augment=self.augment,
-                          use_mini_mask=self.config.USE_MINI_MASK)
+                          use_mini_mask=self.config.USE_MINI_MASK, augmentation=self.augmentation)
 
         # Skip images that have no instances. This can happen in cases
         # where we train on a subset of classes and the image doesn't
@@ -1733,7 +1762,7 @@ def set_bn_eval(m):
         return [rpn_class_logits, rpn_bbox, target_class_ids, mrcnn_class_logits,
                 target_deltas, mrcnn_bbox, target_mask, mrcnn_mask]
 
-    def train_model(self, train_dataset, val_dataset, learning_rate, epochs, layers):
+    def train_model(self, train_dataset, val_dataset, learning_rate, epochs, layers, augmentation=None):
         """Train the model.
         train_dataset, val_dataset: Training and validation Dataset objects.
         learning_rate: The learning rate to train with
@@ -1766,7 +1795,7 @@ def train_model(self, train_dataset, val_dataset, learning_rate, epochs, layers)
         layers = layer_regex[layers]
 
         # Data generators
-        train_set = Dataset(train_dataset, self.config, augment=True)
+        train_set = Dataset(train_dataset, self.config, augment=True, augmentation=augmentation)
         train_generator = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True, num_workers=4)
         val_set = Dataset(val_dataset, self.config, augment=True)
         val_generator = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=True, num_workers=4)
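A note on the `to_deterministic()` call in `load_image_gt`: a stochastic augmenter re-samples its parameters on every call, so augmenting the image and the mask with two separate calls would apply two different transforms. A deterministic copy replays the same sampled transform for both. The sketch below is a minimal standalone illustration, not part of the patch; the square image/mask pair is a made-up stand-in, and `order=0` (nearest-neighbor interpolation) is used so the pixel-wise comparison is exact.

```python
import numpy as np
import imgaug.augmenters as iaa

# Nearest-neighbor interpolation so image and mask pixels map identically.
aug = iaa.Affine(rotate=(-180, 180), order=0)

image = np.zeros((64, 64, 3), dtype=np.uint8)
mask = np.zeros((64, 64, 1), dtype=np.uint8)
image[24:40, 24:40] = 255  # a white square
mask[24:40, 24:40] = 1     # its instance mask

# Independent calls sample two different rotations; the mask drifts
# out of alignment with its image.
misaligned_image = aug.augment_image(image)
misaligned_mask = aug.augment_image(mask)

# A deterministic copy freezes the sampled rotation, so applying it to
# both arrays keeps them aligned. This is what load_image_gt relies on.
det = aug.to_deterministic()
aligned_image = det.augment_image(image)
aligned_mask = det.augment_image(mask)
assert ((aligned_image[..., 0] > 0) == (aligned_mask[..., 0] > 0)).all()
```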
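A quick way to vet a pipeline like the one added to coco.py is to push a dummy image/mask pair through the same hook mechanism the patch uses and check that shapes survive, mirroring the asserts in `load_image_gt`. Standalone sketch under stated assumptions: the pipeline is a trimmed-down variant of the one in the diff, and the random image and square mask are hypothetical stand-ins for a dataset sample.

```python
import numpy as np
import imgaug
import imgaug.augmenters as iaa

# Same whitelist as the model.py hunk: only these augmenters touch masks.
MASK_AUGMENTERS = ["Sequential", "SomeOf", "OneOf", "Sometimes",
                   "Fliplr", "Flipud", "CropAndPad",
                   "Affine", "PiecewiseAffine"]

def hook(images, augmenter, parents, default):
    return augmenter.__class__.__name__ in MASK_AUGMENTERS

augmentation = iaa.Sequential([
    iaa.Fliplr(0.5),
    iaa.GaussianBlur(sigma=(0, 0.25)),  # skipped for the mask by the hook
    iaa.Affine(rotate=(-180, 180)),     # applied to image AND mask
], random_order=True)

# Dummy stand-ins for a dataset sample.
image = np.random.randint(0, 255, (128, 128, 3), dtype=np.uint8)
mask = np.zeros((128, 128, 1), dtype=np.uint8)
mask[32:96, 32:96, 0] = 1

# Freeze the random parameters so image and mask get the same geometry.
det = augmentation.to_deterministic()
aug_image = det.augment_image(image)
aug_mask = det.augment_image(mask, hooks=imgaug.HooksImages(activator=hook))

assert aug_image.shape == image.shape and aug_mask.shape == mask.shape
```

The split the hook encodes is that pixel-level augmenters (blur, noise, contrast) run only on the image, while geometric ones (Fliplr, Affine, crops) must run on both so the mask keeps tracking its instance.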