Skip to content

More augmentation #85

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import model as modellib

import torch
import imgaug.augmenters as iaa

# Root directory of the project
ROOT_DIR = os.getcwd()
Expand Down Expand Up @@ -474,6 +475,40 @@ class InferenceConfig(CocoConfig):
print("Loading weights ", model_path)
model.load_weights(model_path)



augmentation = iaa.Sometimes(.667, iaa.Sequential([
iaa.Fliplr(0.5), # horizontal flips
iaa.Crop(percent=(0, 0.1)), # random crops
# Small gaussian blur with random sigma between 0 and 0.25.
# But we only blur about 50% of all images.
iaa.Sometimes(0.5,
iaa.GaussianBlur(sigma=(0, 0.25))
),
# Strengthen or weaken the contrast in each image.
iaa.ContrastNormalization((0.75, 1.5)),
# Add gaussian noise.
# per_channel is not passed here, so imgaug's default applies: the noise
# is sampled once per pixel and the same value is added to every channel
# of that pixel.
iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255)),
# Make some images brighter and some darker.
# per_channel is not passed here, so imgaug's default applies: a single
# multiplier is sampled per image and applied to all channels alike.
iaa.Multiply((0.8, 1.2)),
# Apply affine transformations to each image.
# Scale/zoom them and rotate them; translation and shear are commented out.
iaa.Affine(
scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
# translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
rotate=(-180, 180),
# shear=(-8, 8)
)
], random_order=True)) # apply augmenters in random order



# Train or evaluate
if args.command == "train":
# Training dataset. Use the training set and 35K from the
Expand All @@ -495,23 +530,26 @@ class InferenceConfig(CocoConfig):
model.train_model(dataset_train, dataset_val,
learning_rate=config.LEARNING_RATE,
epochs=40,
layers='heads')
layers='heads',
augmentation=augmentation)

# Training - Stage 2
# Finetune layers from ResNet stage 4 and up
print("Fine tune Resnet stage 4 and up")
model.train_model(dataset_train, dataset_val,
learning_rate=config.LEARNING_RATE,
epochs=120,
layers='4+')
layers='4+',
augmentation=augmentation)

# Training - Stage 3
# Fine tune all layers
print("Fine tune all layers")
model.train_model(dataset_train, dataset_val,
learning_rate=config.LEARNING_RATE / 10,
epochs=160,
layers='all')
layers='all',
augmentation=augmentation)

elif args.command == "evaluate":
# Validation dataset
Expand Down
39 changes: 34 additions & 5 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,7 +1137,7 @@ def compute_losses(rpn_match, rpn_bbox, rpn_class_logits, rpn_pred_bbox, target_
############################################################

def load_image_gt(dataset, config, image_id, augment=False,
use_mini_mask=False):
use_mini_mask=False, augmentation=None):
"""Load and return ground truth data for an image (image, mask, bounding boxes).

augment: If true, apply random image augmentation. Currently, only
Expand Down Expand Up @@ -1174,6 +1174,34 @@ def load_image_gt(dataset, config, image_id, augment=False,
image = np.fliplr(image)
mask = np.fliplr(mask)

if augmentation:
import imgaug

# Augmenters that are safe to apply to masks
# Some, such as Affine, have settings that make them unsafe, so always
# test your augmentation on masks
MASK_AUGMENTERS = ["Sequential", "SomeOf", "OneOf", "Sometimes",
"Fliplr", "Flipud", "CropAndPad",
"Affine", "PiecewiseAffine"]

def hook(images, augmenter, parents, default):
"""Determines which augmenters to apply to masks."""
return augmenter.__class__.__name__ in MASK_AUGMENTERS

# Store shapes before augmentation to compare
image_shape = image.shape
mask_shape = mask.shape
# Make augmenters deterministic to apply similarly to images and masks
det = augmentation.to_deterministic()
image = det.augment_image(image)
mask = det.augment_image(mask.astype(np.uint8),
hooks=imgaug.HooksImages(activator=hook))
# Verify that shapes didn't change
assert image.shape == image_shape, "Augmentation shouldn't change image size"
assert mask.shape == mask_shape, "Augmentation shouldn't change mask size"
# Change mask back to bool
mask = mask.astype(np.bool)

# Bounding boxes. Note that some boxes might be all zeros
# if the corresponding mask got cropped out.
# bbox: [num_instances, (y1, x1, y2, x2)]
Expand Down Expand Up @@ -1306,7 +1334,7 @@ def build_rpn_targets(image_shape, anchors, gt_class_ids, gt_boxes, config):
return rpn_match, rpn_bbox

class Dataset(torch.utils.data.Dataset):
def __init__(self, dataset, config, augment=True):
def __init__(self, dataset, config, augment=True, augmentation=None):
"""A generator that returns images and corresponding target class ids,
bounding box deltas, and masks.

Expand Down Expand Up @@ -1342,6 +1370,7 @@ def __init__(self, dataset, config, augment=True):
self.dataset = dataset
self.config = config
self.augment = augment
self.augmentation = augmentation

# Anchors
# [anchor_count, (y1, x1, y2, x2)]
Expand All @@ -1356,7 +1385,7 @@ def __getitem__(self, image_index):
image_id = self.image_ids[image_index]
image, image_metas, gt_class_ids, gt_boxes, gt_masks = \
load_image_gt(self.dataset, self.config, image_id, augment=self.augment,
use_mini_mask=self.config.USE_MINI_MASK)
use_mini_mask=self.config.USE_MINI_MASK, augmentation=self.augmentation)

# Skip images that have no instances. This can happen in cases
# where we train on a subset of classes and the image doesn't
Expand Down Expand Up @@ -1733,7 +1762,7 @@ def set_bn_eval(m):

return [rpn_class_logits, rpn_bbox, target_class_ids, mrcnn_class_logits, target_deltas, mrcnn_bbox, target_mask, mrcnn_mask]

def train_model(self, train_dataset, val_dataset, learning_rate, epochs, layers):
def train_model(self, train_dataset, val_dataset, learning_rate, epochs, layers, augmentation=None):
"""Train the model.
train_dataset, val_dataset: Training and validation Dataset objects.
learning_rate: The learning rate to train with
Expand Down Expand Up @@ -1766,7 +1795,7 @@ def train_model(self, train_dataset, val_dataset, learning_rate, epochs, layers)
layers = layer_regex[layers]

# Data generators
train_set = Dataset(train_dataset, self.config, augment=True)
train_set = Dataset(train_dataset, self.config, augment=True, augmentation=augmentation)
train_generator = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True, num_workers=4)
val_set = Dataset(val_dataset, self.config, augment=True)
val_generator = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=True, num_workers=4)
Expand Down