Skip to content

Add support for ObjectDetection #758

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions conf/nasa_marine_debris.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Experiment configuration for the "nasa_marine_debris" object detection task.
# NOTE(review): nesting indentation is not visible in this view; the keys that
# follow each section header (program / trainer / experiment / module /
# datamodule) presumably belong to that section — confirm against the file.

# Script-level options.
program:
seed: 0
overwrite: True

# Presumably forwarded to the pytorch_lightning Trainer — TODO confirm in train.py.
trainer:
gpus: 1
min_epochs: 5
max_epochs: 100
auto_lr_find: False
benchmark: True

# Selects the task and names the run.
experiment:
task: "nasa_marine_debris"
name: "nasa_marine_debris_test"
# Presumably the keyword arguments for ObjectDetectionTask — confirm in train.py.
module:
detection_model: "faster-rcnn"
backbone: "resnet50"
pretrained: True
# NOTE(review): presumably background + debris — confirm against the trainer.
num_classes: 2
learning_rate: 1.2e-4
learning_rate_schedule_patience: 6
verbose: false
# Presumably the keyword arguments for NASAMarineDebrisDataModule — confirm in train.py.
datamodule:
root: "data/nasamr/nasa_marine_debris"
batch_size: 4
# NOTE(review): 56 workers looks tuned for one specific machine — confirm.
num_workers: 56
val_split_pct: 0.2
31 changes: 30 additions & 1 deletion evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
import torch
from torchmetrics import Accuracy, JaccardIndex, MetricCollection

from torchgeo.trainers import ClassificationTask, SemanticSegmentationTask
from torchgeo.trainers import (
ClassificationTask,
ObjectDetectionTask,
SemanticSegmentationTask,
)
from train import TASK_TO_MODULES_MAPPING


Expand Down Expand Up @@ -106,6 +110,14 @@ def run_eval_loop(
y = batch["mask"].to(device)
elif "label" in batch:
y = batch["label"].to(device)
elif "boxes" in batch:
y = [
{
"boxes": batch["boxes"][i].to(device),
"labels": batch["labels"][i].to(device),
}
for i in range(len(batch["image"]))
]
with torch.inference_mode():
y_pred = model(x)
metrics(y_pred, y)
Expand Down Expand Up @@ -176,6 +188,20 @@ def main(args: argparse.Namespace) -> None:
"learning_rate": model.hparams["learning_rate"],
"loss": model.hparams["loss"],
}
elif issubclass(TASK, ObjectDetectionTask):
val_row = {
"split": "val",
"detection_model": model.hparams["detection_model"],
"backbone": model.hparams["backbone"],
"learning_rate": model.hparams["learning_rate"],
}

test_row = {
"split": "test",
"detection_model": model.hparams["detection_model"],
"backbone": model.hparams["backbone"],
"learning_rate": model.hparams["learning_rate"],
}
else:
raise ValueError(f"{TASK} is not supported")

Expand Down Expand Up @@ -240,6 +266,9 @@ def main(args: argparse.Namespace) -> None:
"jaccard_index": test_results["test_JaccardIndex"].item(),
}
)
elif issubclass(TASK, ObjectDetectionTask):
val_row.update({"map": val_results["map"].item()})
test_row.update({"map": test_results["map"].item()})

assert set(val_row.keys()) == set(test_row.keys())
fieldnames = list(test_row.keys())
Expand Down
13 changes: 13 additions & 0 deletions tests/conf/nasa_marine_debris.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Test configuration loaded by tests/trainers/test_detection.py, which reads the
# "experiment" section and uses its "module" and "datamodule" entries as
# keyword arguments.
# NOTE(review): nesting indentation is not visible in this view.
experiment:
task: "nasa_marine_debris"
# Keyword arguments passed to ObjectDetectionTask by the test.
module:
detection_model: "faster-rcnn"
backbone: "resnet18"
num_classes: 2
learning_rate: 1.2e-4
learning_rate_schedule_patience: 6
verbose: false
# Keyword arguments passed to the datamodule class by the test.
datamodule:
root: "tests/data/nasa_marine_debris"
batch_size: 1
num_workers: 0
72 changes: 72 additions & 0 deletions tests/trainers/test_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from typing import Any, Dict, Type, cast

import pytest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer

from torchgeo.datamodules import NASAMarineDebrisDataModule
from torchgeo.trainers import ObjectDetectionTask


class TestObjectDetectionTask:
    """Smoke and error-path tests for :class:`ObjectDetectionTask`."""

    @pytest.mark.parametrize(
        "name,classname", [("nasa_marine_debris", NASAMarineDebrisDataModule)]
    )
    def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
        """Fit, test and predict using ``tests/conf/<name>.yaml``."""
        config_path = os.path.join("tests", "conf", f"{name}.yaml")
        experiment = cast(
            Dict[Any, Dict[Any, Any]],
            OmegaConf.to_object(OmegaConf.load(config_path).experiment),
        )

        # Build the datamodule and the task from their own config sections
        datamodule = classname(**experiment["datamodule"])
        task = ObjectDetectionTask(**experiment["module"])

        # Exercise the fit, test and predict entry points in turn
        runner = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
        runner.fit(model=task, datamodule=datamodule)
        runner.test(model=task, datamodule=datamodule)
        runner.predict(model=task, dataloaders=datamodule.val_dataloader())

    @pytest.fixture
    def model_kwargs(self) -> Dict[Any, Any]:
        """Baseline keyword arguments shared by the tests below."""
        return dict(detection_model="faster-rcnn", backbone="resnet18", num_classes=2)

    def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
        """An unknown ``detection_model`` is expected to raise ``ValueError``."""
        model_kwargs.update(detection_model="invalid_model")
        with pytest.raises(
            ValueError, match="Model type 'invalid_model' is not valid."
        ):
            ObjectDetectionTask(**model_kwargs)

    def test_invalid_backbone(self, model_kwargs: Dict[Any, Any]) -> None:
        """An unknown ``backbone`` is expected to raise ``ValueError``."""
        model_kwargs.update(backbone="invalid_backbone")
        with pytest.raises(
            ValueError, match="Backbone type 'invalid_backbone' is not valid."
        ):
            ObjectDetectionTask(**model_kwargs)

    def test_non_pretrained_backbone(self, model_kwargs: Dict[Any, Any]) -> None:
        """The task can be constructed with ``pretrained=False``."""
        model_kwargs.update(pretrained=False)
        ObjectDetectionTask(**model_kwargs)

    def test_missing_attributes(
        self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
    ) -> None:
        """Validation is run against a datamodule whose ``plot`` was removed."""
        # Strip ``plot`` from the class so the trainer sees a datamodule without it
        monkeypatch.delattr(NASAMarineDebrisDataModule, "plot")
        datamodule = NASAMarineDebrisDataModule(
            root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0
        )
        task = ObjectDetectionTask(**model_kwargs)
        runner = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
        runner.validate(model=task, datamodule=datamodule)
13 changes: 11 additions & 2 deletions torchgeo/datamodules/nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Any, Dict, List, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch import Tensor
Expand All @@ -30,6 +31,7 @@ def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
output: Dict[str, Any] = {}
output["image"] = torch.stack([sample["image"] for sample in batch])
output["boxes"] = [sample["boxes"] for sample in batch]
output["labels"] = [torch.tensor([1] * len(sample["boxes"])) for sample in batch]
return output


Expand Down Expand Up @@ -92,9 +94,9 @@ def setup(self, stage: Optional[str] = None) -> None:
Args:
stage: stage to set up
"""
dataset = NASAMarineDebris(transforms=self.preprocess, **self.kwargs)
self.dataset = NASAMarineDebris(transforms=self.preprocess, **self.kwargs)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
)

def train_dataloader(self) -> DataLoader[Any]:
Expand Down Expand Up @@ -138,3 +140,10 @@ def test_dataloader(self) -> DataLoader[Any]:
shuffle=False,
collate_fn=collate_fn,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
    """Plot by delegating to the dataset held in ``self.dataset``.

    Runs :meth:`torchgeo.datasets.NASAMarineDebris.plot` with every
    argument forwarded unchanged.

    Args:
        *args: positional arguments passed through to the dataset's ``plot``
        **kwargs: keyword arguments passed through to the dataset's ``plot``

    Returns:
        the value returned by the dataset's ``plot``

    .. versionadded:: 0.4
    """
    dataset = self.dataset
    figure = dataset.plot(*args, **kwargs)
    return figure
6 changes: 6 additions & 0 deletions torchgeo/datasets/nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
boxes = self._load_target(self.files[index]["target"])
sample = {"image": image, "boxes": boxes}

# Filter invalid boxes
w_check = (sample["boxes"][:, 2] - sample["boxes"][:, 0]) > 0
h_check = (sample["boxes"][:, 3] - sample["boxes"][:, 1]) > 0
indices = w_check & h_check
sample["boxes"] = sample["boxes"][indices]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How many of the boxes were invalid?

@isaaccorley recently had to do the same thing for IDTReeS, not sure if we should try to reuse code for that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can reuse the IDTReeS._filter_boxes method and make it a function so it can be used for both datasets. Doesn't have to be in this PR but it seems as though this may be a common issue with object detection datasets.


if self.transforms is not None:
sample = self.transforms(sample)

Expand Down
2 changes: 2 additions & 0 deletions torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

from .byol import BYOLTask
from .classification import ClassificationTask, MultiLabelClassificationTask
from .detection import ObjectDetectionTask
from .regression import RegressionTask
from .segmentation import SemanticSegmentationTask

__all__ = (
"BYOLTask",
"ClassificationTask",
"MultiLabelClassificationTask",
"ObjectDetectionTask",
"RegressionTask",
"SemanticSegmentationTask",
)
Expand Down
Loading