Skip to content

Commit 02ac4ae

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] make fill defaultdict an implementation detail (#7258)
Reviewed By: vmoens Differential Revision: D44416563 fbshipit-source-id: 3ac6e3f6a7b6cfa4766c0a4a50b643d47a35e265
1 parent 1b000ff commit 02ac4ae

File tree

3 files changed

+29
-38
lines changed

3 files changed

+29
-38
lines changed

torchvision/prototype/transforms/_geometry.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def __init__(
2222
self.crop_height = size[0]
2323
self.crop_width = size[1]
2424

25-
self.fill = _setup_fill_arg(fill)
25+
self.fill = fill
26+
self._fill = _setup_fill_arg(fill)
2627

2728
self.padding_mode = padding_mode
2829

@@ -118,7 +119,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
118119
)
119120

120121
if params["needs_pad"]:
121-
fill = self.fill[type(inpt)]
122+
fill = self._fill[type(inpt)]
122123
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
123124

124125
return inpt

torchvision/transforms/v2/_geometry.py

+23-20
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
255255
params = super()._extract_params_for_v1_transform()
256256

257257
if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
258-
raise ValueError(
259-
f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
260-
)
258+
raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
261259

262260
return params
263261

@@ -276,11 +274,12 @@ def __init__(
276274
if not isinstance(padding, int):
277275
padding = list(padding)
278276
self.padding = padding
279-
self.fill = _setup_fill_arg(fill)
277+
self.fill = fill
278+
self._fill = _setup_fill_arg(fill)
280279
self.padding_mode = padding_mode
281280

282281
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
283-
fill = self.fill[type(inpt)]
282+
fill = self._fill[type(inpt)]
284283
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
285284

286285

@@ -293,7 +292,8 @@ def __init__(
293292
) -> None:
294293
super().__init__(p=p)
295294

296-
self.fill = _setup_fill_arg(fill)
295+
self.fill = fill
296+
self._fill = _setup_fill_arg(fill)
297297

298298
_check_sequence_input(side_range, "side_range", req_sizes=(2,))
299299

@@ -318,7 +318,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
318318
return dict(padding=padding)
319319

320320
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
321-
fill = self.fill[type(inpt)]
321+
fill = self._fill[type(inpt)]
322322
return F.pad(inpt, **params, fill=fill)
323323

324324

@@ -338,7 +338,8 @@ def __init__(
338338
self.interpolation = _check_interpolation(interpolation)
339339
self.expand = expand
340340

341-
self.fill = _setup_fill_arg(fill)
341+
self.fill = fill
342+
self._fill = _setup_fill_arg(fill)
342343

343344
if center is not None:
344345
_check_sequence_input(center, "center", req_sizes=(2,))
@@ -350,7 +351,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
350351
return dict(angle=angle)
351352

352353
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
353-
fill = self.fill[type(inpt)]
354+
fill = self._fill[type(inpt)]
354355
return F.rotate(
355356
inpt,
356357
**params,
@@ -395,7 +396,8 @@ def __init__(
395396
self.shear = shear
396397

397398
self.interpolation = _check_interpolation(interpolation)
398-
self.fill = _setup_fill_arg(fill)
399+
self.fill = fill
400+
self._fill = _setup_fill_arg(fill)
399401

400402
if center is not None:
401403
_check_sequence_input(center, "center", req_sizes=(2,))
@@ -430,7 +432,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
430432
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
431433

432434
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
433-
fill = self.fill[type(inpt)]
435+
fill = self._fill[type(inpt)]
434436
return F.affine(
435437
inpt,
436438
**params,
@@ -447,9 +449,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
447449
params = super()._extract_params_for_v1_transform()
448450

449451
if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
450-
raise ValueError(
451-
f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
452-
)
452+
raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
453453

454454
padding = self.padding
455455
if padding is not None:
@@ -478,7 +478,8 @@ def __init__(
478478

479479
self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type]
480480
self.pad_if_needed = pad_if_needed
481-
self.fill = _setup_fill_arg(fill)
481+
self.fill = fill
482+
self._fill = _setup_fill_arg(fill)
482483
self.padding_mode = padding_mode
483484

484485
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
@@ -541,7 +542,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
541542

542543
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
543544
if params["needs_pad"]:
544-
fill = self.fill[type(inpt)]
545+
fill = self._fill[type(inpt)]
545546
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
546547

547548
if params["needs_crop"]:
@@ -567,7 +568,8 @@ def __init__(
567568

568569
self.distortion_scale = distortion_scale
569570
self.interpolation = _check_interpolation(interpolation)
570-
self.fill = _setup_fill_arg(fill)
571+
self.fill = fill
572+
self._fill = _setup_fill_arg(fill)
571573

572574
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
573575
height, width = query_spatial_size(flat_inputs)
@@ -600,7 +602,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
600602
return dict(coefficients=perspective_coeffs)
601603

602604
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
603-
fill = self.fill[type(inpt)]
605+
fill = self._fill[type(inpt)]
604606
return F.perspective(
605607
inpt,
606608
None,
@@ -626,7 +628,8 @@ def __init__(
626628
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
627629

628630
self.interpolation = _check_interpolation(interpolation)
629-
self.fill = _setup_fill_arg(fill)
631+
self.fill = fill
632+
self._fill = _setup_fill_arg(fill)
630633

631634
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
632635
size = list(query_spatial_size(flat_inputs))
@@ -652,7 +655,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
652655
return dict(displacement=displacement)
653656

654657
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
655-
fill = self.fill[type(inpt)]
658+
fill = self._fill[type(inpt)]
656659
return F.elastic(
657660
inpt,
658661
**params,

torchvision/transforms/v2/_transform.py

+3-16
Original file line numberDiff line numberDiff line change
@@ -108,30 +108,17 @@ def __init_subclass__(cls) -> None:
108108

109109
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
110110
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
111-
# v2 transform instance. It does two things:
112-
# 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general
113-
# 2. If available handle the `fill` attribute for v1 compatibility (see below for details)
111+
# v2 transform instance. It extracts all available public attributes that are specific to that transform and
112+
# not `nn.Module` in general.
114113
# Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
115114
# if the v2 transform introduced new parameters that are not support by the v1 transform.
116115
common_attrs = nn.Module().__dict__.keys()
117-
params = {
116+
return {
118117
attr: value
119118
for attr, value in self.__dict__.items()
120119
if not attr.startswith("_") and attr not in common_attrs
121120
}
122121

123-
# transforms v2 has a more complex handling for the `fill` parameter than v1. By default, the input is parsed
124-
# with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value
125-
# for the different datapoint types. Below we extract the value for tensors and return that together with the
126-
# other params.
127-
# This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and
128-
# `RandomRotation`
129-
if "fill" in params:
130-
fill_type_defaultdict = params.pop("fill")
131-
params["fill"] = fill_type_defaultdict[torch.Tensor]
132-
133-
return params
134-
135122
def __prepare_scriptable__(self) -> nn.Module:
136123
# This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return
137124
# value is used for scripting over the original object that should have been scripted. Since the v1 transforms

0 commit comments

Comments
 (0)