Skip to content

Commit a46d97c

Browse files
pmeiervfdev-5
andauthored
align transforms v2 signatures with v1 (#7301)
Co-authored-by: vfdev <[email protected]>
1 parent 49c6961 commit a46d97c

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

test/test_transforms_v2_consistency.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -540,9 +540,12 @@ def test_signature_consistency(config):
540540
f"not. Please add a default value."
541541
)
542542

543-
legacy_kinds = {name: param.kind for name, param in legacy_params.items()}
544-
prototype_kinds = {name: prototype_params[name].kind for name in legacy_kinds.keys()}
545-
assert prototype_kinds == legacy_kinds
543+
legacy_signature = list(legacy_params.keys())
544+
# Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
545+
# to the same number of parameters as the legacy one
546+
prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]
547+
548+
assert prototype_signature == legacy_signature
546549

547550

548551
def check_call_consistency(

torchvision/transforms/v2/_container.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ class RandomChoice(Transform):
124124
def __init__(
125125
self,
126126
transforms: Sequence[Callable],
127-
probabilities: Optional[List[float]] = None,
128127
p: Optional[List[float]] = None,
128+
probabilities: Optional[List[float]] = None,
129129
) -> None:
130130
if not isinstance(transforms, Sequence):
131131
raise TypeError("Argument transforms should be a sequence of callables")

torchvision/transforms/v2/_geometry.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -575,8 +575,8 @@ def __init__(
575575
degrees: Union[numbers.Number, Sequence],
576576
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
577577
expand: bool = False,
578-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
579578
center: Optional[List[float]] = None,
579+
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
580580
) -> None:
581581
super().__init__()
582582
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
@@ -903,9 +903,9 @@ class RandomPerspective(_RandomApplyTransform):
903903
def __init__(
904904
self,
905905
distortion_scale: float = 0.5,
906-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
907-
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
908906
p: float = 0.5,
907+
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
908+
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
909909
) -> None:
910910
super().__init__(p=p)
911911

@@ -966,8 +966,8 @@ def __init__(
966966
self,
967967
alpha: Union[float, Sequence[float]] = 50.0,
968968
sigma: Union[float, Sequence[float]] = 5.0,
969-
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
970969
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
970+
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
971971
) -> None:
972972
super().__init__()
973973
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)

0 commit comments

Comments
 (0)