10 changes: 5 additions & 5 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -484,7 +484,7 @@ def forward(self, *inputs: Any) -> Any:

orig_dims = list(image_or_video.shape)
expected_ndim = 5 if isinstance(orig_image_or_video, features.Video) else 4
batch = image_or_video.view([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

# Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
@@ -497,9 +497,9 @@ def forward(self, *inputs: Any) -> Any:
# Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
combined_weights = self._sample_dirichlet(
torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
) * m[:, 1].view([batch_dims[0], -1])
) * m[:, 1].reshape([batch_dims[0], -1])

mix = m[:, 0].view(batch_dims) * batch
mix = m[:, 0].reshape(batch_dims) * batch
for i in range(self.mixture_width):
aug = batch
depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
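The comments in the hunk above describe how the mixing weights are drawn; a rough standalone sketch of that scheme using the public torch.distributions API (alpha, mixture_width and batch_size are made-up example values, and the transform's internal _sample_dirichlet helper is replaced by Dirichlet.sample):

```python
import torch
from torch.distributions import Dirichlet

alpha, mixture_width, batch_size = 1.0, 3, 4

# A 2-component Dirichlet is a Beta: m[:, 0] weights the original batch,
# m[:, 1] weights the augmented mixture.
m = Dirichlet(torch.tensor([alpha, alpha])).sample((batch_size,))

# Per-chain weights for the augmented images, scaled by the Beta weight.
combined = Dirichlet(torch.full((mixture_width,), alpha)).sample((batch_size,)) * m[:, 1:2]

# m[:, 0] + combined.sum(dim=1) should equal 1 for every sample (up to float error).
print(torch.allclose(m[:, 0] + combined.sum(dim=1), torch.ones(batch_size)))
```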
@@ -517,8 +517,8 @@ def forward(self, *inputs: Any) -> Any:
aug = self._apply_image_or_video_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
mix.add_(combined_weights[:, i].view(batch_dims) * aug)
mix = mix.view(orig_dims).to(dtype=image_or_video.dtype)
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

if isinstance(orig_image_or_video, (features.Image, features.Video)):
mix = orig_image_or_video.wrap_like(orig_image_or_video, mix) # type: ignore[arg-type]
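Context for the view-to-reshape swap, as a minimal sketch relying only on standard PyTorch semantics: Tensor.view must be expressible over the tensor's existing strides and raises on many non-contiguous inputs, while Tensor.reshape falls back to a copy when a view is impossible.

```python
import torch

x = torch.arange(6).reshape(2, 3)
t = x.t()                  # transposed -> non-contiguous strides

flat = t.reshape(6)        # works: silently copies because a view is impossible
try:
    t.view(6)              # raises RuntimeError for this stride layout
except RuntimeError as err:
    print(err)             # the message suggests using .reshape(...) instead
```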
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_misc.py
@@ -88,9 +88,9 @@ def _transform(
f"Got {inpt.device} vs {self.mean_vector.device}"
)

flat_tensor = inpt.view(-1, n) - self.mean_vector
flat_tensor = inpt.reshape(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
return transformed_tensor.view(shape)
return transformed_tensor.reshape(shape)
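A small self-contained sketch of the flatten, subtract mean, matrix multiply, restore shape pattern used by _transform above (the identity matrix and zero mean are illustrative stand-ins, not the transform's real attributes):

```python
import torch

inpt = torch.randn(2, 3, 4, 4)              # a small batch of 3x4x4 "images"
n = inpt[0].numel()                          # 48 features per sample
mean_vector = torch.zeros(n)
transformation_matrix = torch.eye(n)         # identity stands in for a whitening matrix

flat = inpt.reshape(-1, n) - mean_vector     # (2, 48); reshape also handles non-contiguous inputs
out = torch.mm(flat, transformation_matrix).reshape(inpt.shape)

print(torch.allclose(out, inpt))             # True with the identity matrix and zero mean
```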


class Normalize(Transform):
8 changes: 4 additions & 4 deletions torchvision/prototype/transforms/functional/_color.py
@@ -69,15 +69,15 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
shape = image.shape

if image.ndim > 4:
image = image.view(-1, num_channels, height, width)
image = image.reshape(-1, num_channels, height, width)
needs_unsquash = True
else:
needs_unsquash = False

output = _FT._blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)

if needs_unsquash:
output = output.view(shape)
output = output.reshape(shape)

return output
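The same squash/unsquash idiom recurs in several hunks below (perspective, elastic, gaussian_blur): any extra leading dims are folded into a single batch dim so a kernel that only accepts (N, C, H, W) can run, then the original shape is restored. A minimal sketch with a stand-in kernel:

```python
import torch

def kernel_4d(batch: torch.Tensor) -> torch.Tensor:
    # stand-in for an op that only accepts (N, C, H, W)
    assert batch.ndim == 4
    return batch * 2

image = torch.rand(2, 5, 3, 8, 8)                # e.g. (batch, frames, C, H, W) video
shape = image.shape

if image.ndim > 4:
    image = image.reshape((-1,) + shape[-3:])    # (10, 3, 8, 8)
    needs_unsquash = True
else:
    needs_unsquash = False

output = kernel_4d(image)

if needs_unsquash:
    output = output.reshape(shape)               # back to (2, 5, 3, 8, 8)
```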

@@ -213,7 +213,7 @@ def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
lut = torch.cat([zeros, lut[:, :-1]], dim=1)

return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).view_as(img))
return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).reshape_as(img))


def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
@@ -227,7 +227,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.numel() == 0:
return image

return _equalize_image_tensor_vec(image.view(-1, height, width)).reshape(image.shape)
return _equalize_image_tensor_vec(image.reshape(-1, height, width)).reshape(image.shape)
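The vectorized _equalize_image_tensor_vec builds a per-image lookup table and gathers through it; a simplified single-channel sketch of the same idea (classic histogram equalization, not the exact torchvision LUT formula):

```python
import torch

img = torch.randint(0, 256, (32, 32), dtype=torch.uint8)

hist = torch.bincount(img.flatten().long(), minlength=256)
cdf = hist.cumsum(0)
cdf_min = cdf[hist.nonzero()[0]]                      # cdf value at the first non-empty bin
denom = (img.numel() - cdf_min).clamp(min=1)
lut = ((cdf - cdf_min) * 255 / denom).clamp(0, 255).to(torch.uint8)

equalized = lut[img.long()]                           # gather through the lookup table
```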


equalize_image_pil = _FP.equalize
66 changes: 33 additions & 33 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -38,13 +38,13 @@ def horizontal_flip_bounding_box(

bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
).reshape(-1, 4)

bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]]

return convert_format_bounding_box(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(shape)
).reshape(shape)
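A quick worked check of the coordinate swap above (horizontal case; the vertical flip below does the same with the y coordinates and the image height):

```python
import torch

box = torch.tensor([[10., 20., 30., 40.]])      # XYXY
width = 100

# horizontal flip: new_x1 = W - old_x2, new_x2 = W - old_x1
box[:, [0, 2]] = width - box[:, [2, 0]]
print(box)                                       # tensor([[70., 20., 90., 40.]])
```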


def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
@@ -75,13 +75,13 @@ def vertical_flip_bounding_box(

bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
).reshape(-1, 4)

bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]]

return convert_format_bounding_box(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(shape)
).reshape(shape)


def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
@@ -123,7 +123,7 @@ def resize_image_tensor(
extra_dims = image.shape[:-3]

if image.numel() > 0:
image = image.view(-1, num_channels, old_height, old_width)
image = image.reshape(-1, num_channels, old_height, old_width)

image = _FT.resize(
image,
@@ -132,7 +132,7 @@ def resize_image_tensor(
antialias=antialias,
)

return image.view(extra_dims + (num_channels, new_height, new_width))
return image.reshape(extra_dims + (num_channels, new_height, new_width))


@torch.jit.unused
@@ -168,7 +168,7 @@ def resize_bounding_box(
new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)
ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
return (
bounding_box.view(-1, 2, 2).mul(ratios).to(bounding_box.dtype).view(bounding_box.shape),
bounding_box.reshape(-1, 2, 2).mul(ratios).to(bounding_box.dtype).reshape(bounding_box.shape),
(new_height, new_width),
)
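A tiny numeric check of the ratio multiply above, with made-up sizes (a 100x200 HxW image resized to 50x400):

```python
import torch

boxes = torch.tensor([[10., 20., 30., 60.]])     # XYXY in the 100x200 (H x W) image
ratios = torch.tensor([400 / 200, 50 / 100])      # (new_w / old_w, new_h / old_h)

resized = boxes.reshape(-1, 2, 2).mul(ratios).reshape(boxes.shape)
print(resized)                                     # tensor([[20., 10., 60., 30.]])
```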

@@ -270,7 +270,7 @@ def affine_image_tensor(

num_channels, height, width = image.shape[-3:]
extra_dims = image.shape[:-3]
image = image.view(-1, num_channels, height, width)
image = image.reshape(-1, num_channels, height, width)

angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

@@ -283,7 +283,7 @@ def affine_image_tensor(
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
return output.view(extra_dims + (num_channels, height, width))
return output.reshape(extra_dims + (num_channels, height, width))


@torch.jit.unused
@@ -338,20 +338,20 @@ def _affine_bounding_box_xyxy(
dtype=dtype,
device=device,
)
.view(2, 3)
.reshape(2, 3)
.T
)
# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
# Single point structure is similar to
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
# 2) Now let's transform the points using affine matrix
transformed_points = torch.matmul(points, transposed_affine_matrix)
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
# and compute bounding box from 4 transformed points:
transformed_points = transformed_points.view(-1, 4, 2)
transformed_points = transformed_points.reshape(-1, 4, 2)
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
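The numbered comments above outline the corner-point approach; a condensed standalone sketch, with a plain 2x3 affine matrix standing in for the helper's inverse-matrix plumbing:

```python
import torch

boxes = torch.tensor([[10., 10., 30., 20.]])           # XYXY, N = 1
affine = torch.tensor([[0., -1., 0.],                   # 90-degree rotation about the origin
                       [1.,  0., 0.]])

# 1) four corners per box, as homogeneous (x, y, 1) points
points = boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1)], dim=-1)

# 2) transform the corners, then 3) re-wrap them into an axis-aligned box
transformed = torch.matmul(points, affine.T).reshape(-1, 4, 2)
out = torch.cat([transformed.min(dim=1).values, transformed.max(dim=1).values], dim=1)
print(out)                                               # tensor([[-20., 10., -10., 30.]])
```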
@@ -396,15 +396,15 @@ def affine_bounding_box(
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
).reshape(-1, 4)

out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center)

# out_bboxes should be of shape [N boxes, 4]

return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape)
).reshape(original_shape)


def affine_mask(
@@ -539,7 +539,7 @@ def rotate_image_tensor(

if image.numel() > 0:
image = _FT.rotate(
image.view(-1, num_channels, height, width),
image.reshape(-1, num_channels, height, width),
matrix,
interpolation=interpolation.value,
expand=expand,
@@ -549,7 +549,7 @@ def rotate_image_tensor(
else:
new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)

return image.view(extra_dims + (num_channels, new_height, new_width))
return image.reshape(extra_dims + (num_channels, new_height, new_width))


@torch.jit.unused
@@ -585,7 +585,7 @@ def rotate_bounding_box(
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
).reshape(-1, 4)

out_bboxes, spatial_size = _affine_bounding_box_xyxy(
bounding_box,
@@ -601,7 +601,7 @@ def rotate_bounding_box(
return (
convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape),
).reshape(original_shape),
spatial_size,
)

@@ -691,15 +691,15 @@ def _pad_with_scalar_fill(

if image.numel() > 0:
image = _FT.pad(
img=image.view(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
img=image.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
)
new_height, new_width = image.shape[-2:]
else:
left, right, top, bottom = _FT._parse_pad_padding(padding)
new_height = height + top + bottom
new_width = width + left + right

return image.view(extra_dims + (num_channels, new_height, new_width))
return image.reshape(extra_dims + (num_channels, new_height, new_width))


# TODO: This should be removed once pytorch pad supports non-scalar padding values
@@ -714,7 +714,7 @@ def _pad_with_vector_fill(

output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant")
left, right, top, bottom = _parse_pad_padding(padding)
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).view(-1, 1, 1)
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

if top > 0:
output[..., :top, :] = fill
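A quick illustration of the reshape(-1, 1, 1) broadcast used for the per-channel fill above (the padding amount and fill values are made up):

```python
import torch

output = torch.zeros(3, 6, 6)                      # already zero-padded (C, H, W)
fill = torch.tensor([255., 128., 0.]).reshape(-1, 1, 1)

top = 2
output[..., :top, :] = fill                        # each channel receives its own fill value
print(output[:, 0, 0])                             # tensor([255., 128., 0.])
```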
@@ -863,15 +863,15 @@ def perspective_image_tensor(
shape = image.shape

if image.ndim > 4:
image = image.view((-1,) + shape[-3:])
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False

output = _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)

if needs_unsquash:
output = output.view(shape)
output = output.reshape(shape)

return output

@@ -898,7 +898,7 @@ def perspective_bounding_box(
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
).reshape(-1, 4)

dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
device = bounding_box.device
@@ -947,7 +947,7 @@ def perspective_bounding_box(
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
# Single point structure is similar to
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
# 2) Now let's transform the points using perspective matrices
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
Expand All @@ -959,7 +959,7 @@ def perspective_bounding_box(

# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
# and compute bounding box from 4 transformed points:
transformed_points = transformed_points.view(-1, 4, 2)
transformed_points = transformed_points.reshape(-1, 4, 2)
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
@@ -968,7 +968,7 @@ def perspective_bounding_box(

return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape)
).reshape(original_shape)
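A compact sketch of the point transform spelled out in the comments above (the eight-coefficient perspective mapping; the coefficient values are arbitrary):

```python
import torch

coeffs = [1.0, 0.1, 5.0, 0.0, 1.0, -3.0, 0.001, 0.002]
points = torch.tensor([[10., 20., 1.], [30., 40., 1.]])          # homogeneous (x, y, 1)

theta1 = torch.tensor([coeffs[0:3], coeffs[3:6]])                 # numerator rows for x_out, y_out
theta2 = torch.tensor([[coeffs[6], coeffs[7], 1.0]] * 2)          # shared denominator row

numer = points @ theta1.T                                          # (N, 2)
denom = points @ theta2.T                                          # (N, 2), both columns identical
transformed = numer / denom
print(transformed)
```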


def perspective_mask(
@@ -1027,15 +1027,15 @@ def elastic_image_tensor(
shape = image.shape

if image.ndim > 4:
image = image.view((-1,) + shape[-3:])
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False

output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)

if needs_unsquash:
output = output.view(shape)
output = output.reshape(shape)

return output

@@ -1063,7 +1063,7 @@ def elastic_bounding_box(
original_shape = bounding_box.shape
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
).reshape(-1, 4)

# Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
# Or add spatial_size arg and check displacement shape
@@ -1075,21 +1075,21 @@ def elastic_bounding_box(
inv_grid = id_grid - displacement

# Get points from bboxes
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long)
index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long)
# Transform points:
t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5

transformed_points = transformed_points.view(-1, 4, 2)
transformed_points = transformed_points.reshape(-1, 4, 2)
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape)
).reshape(original_shape)
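For the grid arithmetic above: the identity grid lives in normalized [-1, 1] coordinates, so (grid + 1) * 0.5 * size - 0.5 maps it back to pixels. A small sketch using torch.nn.functional.affine_grid as a stand-in for the internal _create_identity_grid helper (assuming the align_corners=False convention):

```python
import torch
import torch.nn.functional as F

h, w = 4, 6
theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])                   # identity affine
id_grid = F.affine_grid(theta, size=(1, 1, h, w), align_corners=False)  # (1, H, W, 2), values in [-1, 1]

displacement = torch.zeros_like(id_grid)                                 # elastic offsets; zero here
inv_grid = id_grid - displacement

t_size = torch.tensor([w, h], dtype=torch.float32)                      # (width, height), like spatial_size[::-1]
pixel_coords = (inv_grid + 1) * 0.5 * t_size - 0.5                      # back to pixel coordinates

# with zero displacement, pixel x-coordinates run 0..w-1 along the width axis
print(pixel_coords[0, 0, :, 0])
```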


def elastic_mask(
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/functional/_misc.py
@@ -65,15 +65,15 @@ def gaussian_blur_image_tensor(
shape = image.shape

if image.ndim > 4:
image = image.view((-1,) + shape[-3:])
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False

output = _FT.gaussian_blur(image, kernel_size, sigma)

if needs_unsquash:
output = output.view(shape)
output = output.reshape(shape)

return output
