Skip to content

Commit e1aacdd

Browse files
authored
Update ToDtype to avoid unnecessary to() calls and fixing types on Transform (#6773)
* Fix `ToDtype` to avoid errors when a type is not defined. * Nit `(features.is_simple_tensor, features._Feature)` to `(Tensor,)` * Fixing linter * Adding comment. * Switch back to indexing. Python's default dict seems to have a nasty behaviour.
1 parent 8ec7a70 commit e1aacdd

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

torchvision/prototype/transforms/_misc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,10 @@ def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
157157
self.dtype = dtype
158158

159159
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
160-
return inpt.to(self.dtype[type(inpt)])
160+
dtype = self.dtype[type(inpt)]
161+
if dtype is None:
162+
return inpt
163+
return inpt.to(dtype=dtype)
161164

162165

163166
class RemoveSmallBoundingBoxes(Transform):

torchvision/prototype/transforms/_transform.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,15 @@
55
import torch
66
from torch import nn
77
from torch.utils._pytree import tree_flatten, tree_unflatten
8-
from torchvision.prototype import features
98
from torchvision.prototype.transforms._utils import _isinstance
109
from torchvision.utils import _log_api_usage_once
1110

1211

1312
class Transform(nn.Module):
1413

1514
# Class attribute defining transformed types. Other types are passed-through without any transformation
16-
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (
17-
features.is_simple_tensor,
18-
features._Feature,
19-
PIL.Image.Image,
20-
)
15+
# We support both Types and callables that are able to do further checks on the type of the input.
16+
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
2117

2218
def __init__(self) -> None:
2319
super().__init__()

0 commit comments

Comments
 (0)