
Commit 8e56a37

Remove support of targets on presets.
1 parent a696473

File tree: 4 files changed, +17 −43 lines changed


gallery/plot_optical_flow.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def plot(imgs, **imshow_kwargs):
 def preprocess(img1_batch, img2_batch):
     img1_batch = F.resize(img1_batch, size=[520, 960])
     img2_batch = F.resize(img2_batch, size=[520, 960])
-    return transforms(img1_batch, img2_batch)[:2]
+    return transforms(img1_batch, img2_batch)
 
 
 img1_batch, img2_batch = preprocess(img1_batch, img2_batch)
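With targets dropped from the preset, the optical-flow transform returns exactly the two image tensors, so the gallery no longer needs the [:2] slice to discard the optional flow and mask slots. A minimal usage sketch, assuming Raft_Large_Weights is importable as in the rest of this gallery script and using random uint8 batches as stand-ins for the real frames:

import torch
import torchvision.transforms.functional as F
from torchvision.models.optical_flow import Raft_Large_Weights

weights = Raft_Large_Weights.DEFAULT
transforms = weights.transforms()

img1_batch = torch.randint(0, 256, (2, 3, 520, 960), dtype=torch.uint8)
img2_batch = torch.randint(0, 256, (2, 3, 520, 960), dtype=torch.uint8)

# The preset now returns only (img1, img2); no slicing needed.
img1_batch, img2_batch = transforms(img1_batch, img2_batch)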

gallery/plot_repurposing_annotations.py

Lines changed: 2 additions & 5 deletions
@@ -146,11 +146,8 @@ def show(imgs):
 print(img.size())
 
 tranforms = weights.transforms()
-img, _ = tranforms(img)
-target = {}
-target["boxes"] = boxes
-target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64)
-detection_outputs = model(img.unsqueeze(0), [target])
+img = tranforms(img)
+detection_outputs = model(img.unsqueeze(0))
 
 
 ####################################
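The preset no longer consumes or returns a target dict, so the example now calls the model in inference mode with just the image batch. Targets are still supported by the torchvision detection API, but they go to the model rather than to the transform. A hedged sketch of both call styles, reusing the script's variables (including its typo'd tranforms name) and assuming boxes and masks were built earlier in the script:

import torch

img = tranforms(img)                         # preset returns only the float image
detection_outputs = model(img.unsqueeze(0))  # eval mode: predictions

# Training-style call (model must be in train mode); targets bypass the preset:
target = {
    "boxes": boxes,
    "labels": torch.ones((masks.size(0),), dtype=torch.int64),
}
loss_dict = model(img.unsqueeze(0), [target])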

gallery/plot_visualization_utils.py

Lines changed: 4 additions & 4 deletions
@@ -81,7 +81,7 @@ def show(imgs):
 weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
 transforms = weights.transforms()
 
-batch, _ = transforms(batch_int)
+batch = transforms(batch_int)
 
 model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()
@@ -131,7 +131,7 @@ def show(imgs):
 model = fcn_resnet50(weights=weights, progress=False)
 model = model.eval()
 
-normalized_batch, _ = transforms(batch)
+normalized_batch = transforms(batch)
 output = model(normalized_batch)['out']
 print(output.shape, output.min().item(), output.max().item())
 
@@ -272,7 +272,7 @@ def show(imgs):
 weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
 transforms = weights.transforms()
 
-batch, _ = transforms(batch_int)
+batch = transforms(batch_int)
 
 model = maskrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()
@@ -397,7 +397,7 @@ def show(imgs):
 weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
 transforms = weights.transforms()
 
-person_float, _ = transforms(person_int)
+person_float = transforms(person_int)
 
 model = keypointrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()
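All four hunks are the same mechanical change: the presets used to return an (image, target) pair with a placeholder target, and now return the image alone, so the ", _" unpacking disappears. A minimal sketch of the updated segmentation call, assuming the FCN weights enum follows the naming pattern of the detection weights above (FCN_ResNet50_Weights) and using a random uint8 batch in place of the gallery images:

import torch
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

weights = FCN_ResNet50_Weights.DEFAULT
transforms = weights.transforms()

batch = torch.randint(0, 256, (4, 3, 520, 520), dtype=torch.uint8)

normalized_batch = transforms(batch)   # single tensor out, nothing to discard
model = fcn_resnet50(weights=weights, progress=False).eval()

with torch.no_grad():
    output = model(normalized_batch)["out"]
print(output.shape)                    # (4, num_classes, H, W)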

torchvision/transforms/_presets.py

Lines changed: 10 additions & 33 deletions
@@ -20,12 +20,10 @@
 
 
 class ObjectDetectionEval(nn.Module):
-    def forward(
-        self, img: Tensor, target: Optional[Dict[str, Tensor]] = None
-    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+    def forward(self, img: Tensor) -> Tensor:
         if not isinstance(img, Tensor):
             img = F.pil_to_tensor(img)
-        return F.convert_image_dtype(img, torch.float), target
+        return F.convert_image_dtype(img, torch.float)
 
 
 class ImageClassificationEval(nn.Module):
@@ -95,28 +93,22 @@ def __init__(
         self._interpolation = interpolation
         self._interpolation_target = interpolation_target
 
-    def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+    def forward(self, img: Tensor) -> Tensor:
         if isinstance(self._size, list):
             img = F.resize(img, self._size, interpolation=self._interpolation)
         if not isinstance(img, Tensor):
             img = F.pil_to_tensor(img)
         img = F.convert_image_dtype(img, torch.float)
         img = F.normalize(img, mean=self._mean, std=self._std)
-        if target:
-            if isinstance(self._size, list):
-                target = F.resize(target, self._size, interpolation=self._interpolation_target)
-            if not isinstance(target, Tensor):
-                target = F.pil_to_tensor(target)
-            target = target.squeeze(0).to(torch.int64)
-        return img, target
+        return img
 
 
 class OpticalFlowEval(nn.Module):
-    def forward(
-        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor] = None, valid_flow_mask: Optional[Tensor] = None
-    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
-
-        img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask)
+    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
+        if not isinstance(img1, Tensor):
+            img1 = F.pil_to_tensor(img1)
+        if not isinstance(img2, Tensor):
+            img2 = F.pil_to_tensor(img2)
 
         img1 = F.convert_image_dtype(img1, torch.float32)
         img2 = F.convert_image_dtype(img2, torch.float32)
@@ -128,19 +120,4 @@ def forward(
         img1 = img1.contiguous()
         img2 = img2.contiguous()
 
-        return img1, img2, flow, valid_flow_mask
-
-    def _pil_or_numpy_to_tensor(
-        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
-    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
-        if not isinstance(img1, Tensor):
-            img1 = F.pil_to_tensor(img1)
-        if not isinstance(img2, Tensor):
-            img2 = F.pil_to_tensor(img2)
-
-        if flow is not None and not isinstance(flow, Tensor):
-            flow = torch.from_numpy(flow)
-        if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor):
-            valid_flow_mask = torch.from_numpy(valid_flow_mask)
-
-        return img1, img2, flow, valid_flow_mask
+        return img1, img2
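After this commit the eval presets are pure image transforms: ObjectDetectionEval maps one image (PIL or tensor) to one float tensor, the segmentation preset resizes and normalizes a single image, and OpticalFlowEval takes and returns exactly two images, with all the flow/mask plumbing gone. A minimal sketch of the resulting call shapes; note that _presets is a private module, so these imports are illustrative rather than supported API:

import torch
from torchvision.transforms._presets import ObjectDetectionEval, OpticalFlowEval

img = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)

det = ObjectDetectionEval()
float_img = det(img)                   # one tensor in, one float tensor out
assert float_img.dtype == torch.float32

flow_preset = OpticalFlowEval()
img1, img2 = flow_preset(img, img.clone())   # exactly two tensors back, no flow/mask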
