-
Notifications
You must be signed in to change notification settings - Fork 445
Add support for ObjectDetection #758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
31aa89f
2930a34
02b789f
0040b49
6a366b5
0d4f168
c872e63
d4d24ba
074ec1e
835dc8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
program: | ||
seed: 0 | ||
overwrite: True | ||
|
||
trainer: | ||
gpus: 1 | ||
min_epochs: 5 | ||
max_epochs: 100 | ||
auto_lr_find: False | ||
benchmark: True | ||
|
||
experiment: | ||
task: "nasa_marine_debris" | ||
name: "nasa_marine_debris_test" | ||
module: | ||
detection_model: "faster-rcnn" | ||
backbone: "resnet50" | ||
pretrained: True | ||
num_classes: 2 | ||
learning_rate: 1.2e-4 | ||
learning_rate_schedule_patience: 6 | ||
verbose: false | ||
datamodule: | ||
root: "data/nasamr/nasa_marine_debris" | ||
batch_size: 4 | ||
num_workers: 56 | ||
val_split_pct: 0.2 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
experiment: | ||
task: "nasa_marine_debris" | ||
module: | ||
detection_model: "faster-rcnn" | ||
backbone: "resnet18" | ||
num_classes: 2 | ||
learning_rate: 1.2e-4 | ||
learning_rate_schedule_patience: 6 | ||
verbose: false | ||
datamodule: | ||
root: "tests/data/nasa_marine_debris" | ||
batch_size: 1 | ||
num_workers: 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
from typing import Any, Dict, Type, cast | ||
|
||
import pytest | ||
from _pytest.monkeypatch import MonkeyPatch | ||
from omegaconf import OmegaConf | ||
from pytorch_lightning import LightningDataModule, Trainer | ||
|
||
from torchgeo.datamodules import NASAMarineDebrisDataModule | ||
from torchgeo.trainers import ObjectDetectionTask | ||
|
||
|
||
class TestObjectDetectionTask: | ||
@pytest.mark.parametrize( | ||
"name,classname", [("nasa_marine_debris", NASAMarineDebrisDataModule)] | ||
) | ||
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: | ||
conf = OmegaConf.load(os.path.join("tests", "conf", f"{name}.yaml")) | ||
conf_dict = OmegaConf.to_object(conf.experiment) | ||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) | ||
|
||
# Instantiate datamodule | ||
datamodule_kwargs = conf_dict["datamodule"] | ||
datamodule = classname(**datamodule_kwargs) | ||
|
||
# Instantiate model | ||
model_kwargs = conf_dict["module"] | ||
model = ObjectDetectionTask(**model_kwargs) | ||
|
||
# Instantiate trainer | ||
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) | ||
trainer.fit(model=model, datamodule=datamodule) | ||
trainer.test(model=model, datamodule=datamodule) | ||
trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) | ||
|
||
@pytest.fixture | ||
def model_kwargs(self) -> Dict[Any, Any]: | ||
return { | ||
"detection_model": "faster-rcnn", | ||
"backbone": "resnet18", | ||
"num_classes": 2, | ||
} | ||
|
||
def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: | ||
model_kwargs["detection_model"] = "invalid_model" | ||
match = "Model type 'invalid_model' is not valid." | ||
with pytest.raises(ValueError, match=match): | ||
ObjectDetectionTask(**model_kwargs) | ||
|
||
def test_invalid_backbone(self, model_kwargs: Dict[Any, Any]) -> None: | ||
model_kwargs["backbone"] = "invalid_backbone" | ||
match = "Backbone type 'invalid_backbone' is not valid." | ||
with pytest.raises(ValueError, match=match): | ||
ObjectDetectionTask(**model_kwargs) | ||
|
||
def test_non_pretrained_backbone(self, model_kwargs: Dict[Any, Any]) -> None: | ||
model_kwargs["pretrained"] = False | ||
ObjectDetectionTask(**model_kwargs) | ||
|
||
def test_missing_attributes( | ||
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch | ||
) -> None: | ||
monkeypatch.delattr(NASAMarineDebrisDataModule, "plot") | ||
datamodule = NASAMarineDebrisDataModule( | ||
root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 | ||
) | ||
model = ObjectDetectionTask(**model_kwargs) | ||
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) | ||
trainer.validate(model=model, datamodule=datamodule) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -99,6 +99,12 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: | |
boxes = self._load_target(self.files[index]["target"]) | ||
sample = {"image": image, "boxes": boxes} | ||
|
||
# Filter invalid boxes | ||
w_check = (sample["boxes"][:, 2] - sample["boxes"][:, 0]) > 0 | ||
h_check = (sample["boxes"][:, 3] - sample["boxes"][:, 1]) > 0 | ||
indices = w_check & h_check | ||
sample["boxes"] = sample["boxes"][indices] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How many of the boxes were invalid? @isaaccorley recently had to do the same thing for IDTReeS, not sure if we should try to reuse code for that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can reuse the IDTReeS._filter_boxes method and make it a function so it can be used for both datasets. Doesn't have to be in this PR but it seems as though this may be a common issue with object detection datasets. |
||
|
||
if self.transforms is not None: | ||
sample = self.transforms(sample) | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.