Skip to content

Commit c9a3e9f

Browse files
committed
Some more
1 parent 2d83526 commit c9a3e9f

File tree

7 files changed

+135
-67
lines changed

7 files changed

+135
-67
lines changed

test/test_transforms_v2_refactored.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import inspect
44
import math
55
import re
6+
import sys
67
from pathlib import Path
78
from unittest import mock
89

@@ -2381,6 +2382,7 @@ def test_correctness(self):
23812382
assert isinstance(out_value, type(input_value))
23822383

23832384

2385+
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="Windows doesn't support fork()")
23842386
def test_transforms_rng_with_dataloader():
23852387
# This is more of a sanity test for torch core's handling of Generators within the Dataloader
23862388
# But worth having it here as well for security.

torchvision/transforms/v2/_augment.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ def __init__(
6363
ratio: Tuple[float, float] = (0.3, 3.3),
6464
value: float = 0.0,
6565
inplace: bool = False,
66+
generator=None,
6667
):
67-
super().__init__(p=p)
68+
super().__init__(p=p, generator=generator)
6869
if not isinstance(value, (numbers.Number, str, tuple, list)):
6970
raise TypeError("Argument value should be either a number or str or a sequence")
7071
if isinstance(value, str) and value != "random":
@@ -111,11 +112,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
111112

112113
log_ratio = self._log_ratio
113114
for _ in range(10):
114-
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
115+
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=self.generator).item()
115116
aspect_ratio = torch.exp(
116117
torch.empty(1).uniform_(
117118
log_ratio[0], # type: ignore[arg-type]
118119
log_ratio[1], # type: ignore[arg-type]
120+
generator=self.generator,
119121
)
120122
).item()
121123

@@ -129,8 +131,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
129131
else:
130132
v = torch.tensor(self.value)[:, None, None]
131133

132-
i = torch.randint(0, img_h - h + 1, size=(1,)).item()
133-
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
134+
i = torch.randint(0, img_h - h + 1, size=(1,), generator=self.generator).item()
135+
j = torch.randint(0, img_w - w + 1, size=(1,), generator=self.generator).item()
134136
break
135137
else:
136138
i, j, h, w, v = 0, 0, img_h, img_w, None

torchvision/transforms/v2/_auto_augment.py

+24-14
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,16 @@ def __init__(
2525
*,
2626
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
2727
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
28+
generator=None,
2829
) -> None:
2930
super().__init__()
3031
self.interpolation = _check_interpolation(interpolation)
3132
self.fill = _setup_fill_arg(fill)
33+
self.generator = generator
3234

3335
def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
3436
keys = tuple(dct.keys())
35-
key = keys[int(torch.randint(len(keys), ()))]
37+
key = keys[int(torch.randint(len(keys), (), generator=self.generator))]
3638
return key, dct[key]
3739

3840
def _flatten_and_extract_image_or_video(
@@ -219,8 +221,9 @@ def __init__(
219221
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
220222
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
221223
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
224+
generator=None,
222225
) -> None:
223-
super().__init__(interpolation=interpolation, fill=fill)
226+
super().__init__(interpolation=interpolation, fill=fill, generator=generator)
224227
self.policy = policy
225228
self._policies = self._get_policies(policy)
226229

@@ -318,18 +321,18 @@ def forward(self, *inputs: Any) -> Any:
318321
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
319322
height, width = get_size(image_or_video)
320323

321-
policy = self._policies[int(torch.randint(len(self._policies), ()))]
324+
policy = self._policies[int(torch.randint(len(self._policies), (), generator=self.generator))]
322325

323326
for transform_id, probability, magnitude_idx in policy:
324-
if not torch.rand(()) <= probability:
327+
if not torch.rand((), generator=self.generator) <= probability:
325328
continue
326329

327330
magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
328331

329332
magnitudes = magnitudes_fn(10, height, width)
330333
if magnitudes is not None:
331334
magnitude = float(magnitudes[magnitude_idx])
332-
if signed and torch.rand(()) <= 0.5:
335+
if signed and torch.rand((), generator=self.generator) <= 0.5:
333336
magnitude *= -1
334337
else:
335338
magnitude = 0.0
@@ -399,8 +402,9 @@ def __init__(
399402
num_magnitude_bins: int = 31,
400403
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
401404
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
405+
generator=None,
402406
) -> None:
403-
super().__init__(interpolation=interpolation, fill=fill)
407+
super().__init__(interpolation=interpolation, fill=fill, generator=generator)
404408
self.num_ops = num_ops
405409
self.magnitude = magnitude
406410
self.num_magnitude_bins = num_magnitude_bins
@@ -414,7 +418,7 @@ def forward(self, *inputs: Any) -> Any:
414418
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
415419
if magnitudes is not None:
416420
magnitude = float(magnitudes[self.magnitude])
417-
if signed and torch.rand(()) <= 0.5:
421+
if signed and torch.rand((), generator=self.generator) <= 0.5:
418422
magnitude *= -1
419423
else:
420424
magnitude = 0.0
@@ -472,8 +476,9 @@ def __init__(
472476
num_magnitude_bins: int = 31,
473477
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
474478
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
479+
generator=None,
475480
):
476-
super().__init__(interpolation=interpolation, fill=fill)
481+
super().__init__(interpolation=interpolation, fill=fill, generator=generator)
477482
self.num_magnitude_bins = num_magnitude_bins
478483

479484
def forward(self, *inputs: Any) -> Any:
@@ -484,8 +489,8 @@ def forward(self, *inputs: Any) -> Any:
484489

485490
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
486491
if magnitudes is not None:
487-
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
488-
if signed and torch.rand(()) <= 0.5:
492+
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, (), generator=self.generator))])
493+
if signed and torch.rand((), generator=self.generator) <= 0.5:
489494
magnitude *= -1
490495
else:
491496
magnitude = 0.0
@@ -555,8 +560,9 @@ def __init__(
555560
all_ops: bool = True,
556561
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
557562
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
563+
generator=None,
558564
) -> None:
559-
super().__init__(interpolation=interpolation, fill=fill)
565+
super().__init__(interpolation=interpolation, fill=fill, generator=generator)
560566
self._PARAMETER_MAX = 10
561567
if not (1 <= severity <= self._PARAMETER_MAX):
562568
raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
@@ -601,14 +607,18 @@ def forward(self, *inputs: Any) -> Any:
601607
mix = m[:, 0].reshape(batch_dims) * batch
602608
for i in range(self.mixture_width):
603609
aug = batch
604-
depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
610+
depth = (
611+
self.chain_depth
612+
if self.chain_depth > 0
613+
else int(torch.randint(low=1, high=4, size=(1,), generator=self.generator).item())
614+
)
605615
for _ in range(depth):
606616
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
607617

608618
magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
609619
if magnitudes is not None:
610-
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
611-
if signed and torch.rand(()) <= 0.5:
620+
magnitude = float(magnitudes[int(torch.randint(self.severity, (), generator=self.generator))])
621+
if signed and torch.rand((), generator=self.generator) <= 0.5:
612622
magnitude *= -1
613623
else:
614624
magnitude = 0.0

torchvision/transforms/v2/_color.py

+38-12
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,14 @@ def __init__(
9696
contrast: Optional[Union[float, Sequence[float]]] = None,
9797
saturation: Optional[Union[float, Sequence[float]]] = None,
9898
hue: Optional[Union[float, Sequence[float]]] = None,
99+
generator=None,
99100
) -> None:
100101
super().__init__()
101102
self.brightness = self._check_input(brightness, "brightness")
102103
self.contrast = self._check_input(contrast, "contrast")
103104
self.saturation = self._check_input(saturation, "saturation")
104105
self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
106+
self.generator = generator
105107

106108
def _check_input(
107109
self,
@@ -131,16 +133,28 @@ def _check_input(
131133
return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))
132134

133135
@staticmethod
134-
def _generate_value(left: float, right: float) -> float:
135-
return torch.empty(1).uniform_(left, right).item()
136+
def _generate_value(left: float, right: float, generator=None) -> float:
137+
return torch.empty(1).uniform_(left, right, generator=generator).item()
136138

137139
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
138-
fn_idx = torch.randperm(4)
139-
140-
b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
141-
c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1])
142-
s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1])
143-
h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1])
140+
fn_idx = torch.randperm(4, generator=self.generator)
141+
142+
b = (
143+
None
144+
if self.brightness is None
145+
else self._generate_value(self.brightness[0], self.brightness[1], generator=self.generator)
146+
)
147+
c = (
148+
None
149+
if self.contrast is None
150+
else self._generate_value(self.contrast[0], self.contrast[1], generator=self.generator)
151+
)
152+
s = (
153+
None
154+
if self.saturation is None
155+
else self._generate_value(self.saturation[0], self.saturation[1], generator=self.generator)
156+
)
157+
h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1], generator=self.generator)
144158

145159
return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h)
146160

@@ -168,9 +182,13 @@ class RandomChannelPermutation(Transform):
168182
.. v2betastatus:: RandomChannelPermutation transform
169183
"""
170184

185+
def __init__(self, generator=None):
186+
super().__init__()
187+
self.generator = generator
188+
171189
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
172190
num_channels, *_ = query_chw(flat_inputs)
173-
return dict(permutation=torch.randperm(num_channels))
191+
return dict(permutation=torch.randperm(num_channels, generator=self.generator))
174192

175193
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
176194
return self._call_kernel(F.permute_channels, inpt, params["permutation"])
@@ -209,27 +227,35 @@ def __init__(
209227
saturation: Tuple[float, float] = (0.5, 1.5),
210228
hue: Tuple[float, float] = (-0.05, 0.05),
211229
p: float = 0.5,
230+
generator=None,
212231
):
213232
super().__init__()
214233
self.brightness = brightness
215234
self.contrast = contrast
216235
self.hue = hue
217236
self.saturation = saturation
218237
self.p = p
238+
self.generator = generator
219239

220240
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
221241
num_channels, *_ = query_chw(flat_inputs)
222242
params: Dict[str, Any] = {
223-
key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None
243+
key: ColorJitter._generate_value(range[0], range[1])
244+
if torch.rand(1, generator=self.generator) < self.p
245+
else None
224246
for key, range in [
225247
("brightness_factor", self.brightness),
226248
("contrast_factor", self.contrast),
227249
("saturation_factor", self.saturation),
228250
("hue_factor", self.hue),
229251
]
230252
}
231-
params["contrast_before"] = bool(torch.rand(()) < 0.5)
232-
params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
253+
params["contrast_before"] = bool(torch.rand((), generator=self.generator) < 0.5)
254+
params["channel_permutation"] = (
255+
torch.randperm(num_channels, generator=self.generator)
256+
if torch.rand(1, generator=self.generator) < self.p
257+
else None
258+
)
233259
return params
234260

235261
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

torchvision/transforms/v2/_container.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class RandomApply(Transform):
8585

8686
_v1_transform_cls = _transforms.RandomApply
8787

88-
def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
88+
def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5, generator=None) -> None:
8989
super().__init__()
9090

9191
if not isinstance(transforms, (Sequence, nn.ModuleList)):
@@ -95,14 +95,15 @@ def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: floa
9595
if not (0.0 <= p <= 1.0):
9696
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
9797
self.p = p
98+
self.generator = generator
9899

99100
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
100101
return {"transforms": self.transforms, "p": self.p}
101102

102103
def forward(self, *inputs: Any) -> Any:
103104
sample = inputs if len(inputs) > 1 else inputs[0]
104105

105-
if torch.rand(1) >= self.p:
106+
if torch.rand(1, generator=self.generator) >= self.p:
106107
return sample
107108

108109
for transform in self.transforms:
@@ -166,15 +167,16 @@ class RandomOrder(Transform):
166167
transforms (sequence or torch.nn.Module): list of transformations
167168
"""
168169

169-
def __init__(self, transforms: Sequence[Callable]) -> None:
170+
def __init__(self, transforms: Sequence[Callable], generator=None) -> None:
170171
if not isinstance(transforms, Sequence):
171172
raise TypeError("Argument transforms should be a sequence of callables")
172173
super().__init__()
173174
self.transforms = transforms
175+
self.generator = generator
174176

175177
def forward(self, *inputs: Any) -> Any:
176178
sample = inputs if len(inputs) > 1 else inputs[0]
177-
for idx in torch.randperm(len(self.transforms)):
179+
for idx in torch.randperm(len(self.transforms), generator=self.generator):
178180
transform = self.transforms[idx]
179181
sample = transform(sample)
180182
return sample

0 commit comments

Comments
 (0)