Skip to content

Commit 79ac4bf

Browse files
committed
Fix tests.
1 parent 9c284cc commit 79ac4bf

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

test/test_transforms_tensor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,12 @@ def test_augmix(device, fill):
739739
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
740740
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
741741

742-
transform = T.AugMix(fill=fill)
742+
class DeterministicAugMix(T.AugMix):
743+
def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
744+
# path the method to ensure that the order of rand calls doesn't affect the outcome
745+
return params.softmax(dim=-1)
746+
747+
transform = DeterministicAugMix(fill=fill)
743748
s_transform = torch.jit.script(transform)
744749
for _ in range(25):
745750
_test_transform_vs_scripted(transform, s_transform, tensor)

torchvision/transforms/autoaugment.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,10 @@ def _pil_to_tensor(self, img) -> Tensor:
535535
def _tensor_to_pil(self, img: Tensor):
536536
return F.to_pil_image(img)
537537

538+
def _sample_dirichlet(self, params: Tensor) -> Tensor:
539+
# Must be on a separate method so that we can overwrite it in tests.
540+
return torch._sample_dirichlet(params)
541+
538542
def forward(self, orig_img: Tensor) -> Tensor:
539543
"""
540544
img (PIL Image or Tensor): Image to be transformed.
@@ -560,12 +564,12 @@ def forward(self, orig_img: Tensor) -> Tensor:
560564

561565
# Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
562566
# with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
563-
m = torch._sample_dirichlet(
567+
m = self._sample_dirichlet(
564568
torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
565569
)
566570

567571
# Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
568-
combined_weights = torch._sample_dirichlet(
572+
combined_weights = self._sample_dirichlet(
569573
torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
570574
) * m[:, 1].view([batch_dims[0], -1])
571575

0 commit comments

Comments
 (0)