Skip to content

Prevent tests from leaking their respective RNG #4497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG
import torch
import numpy as np
import random
import pytest


Expand Down Expand Up @@ -80,3 +82,26 @@ def pytest_sessionfinish(session, exitstatus):
# To avoid this, we transform this 5 into a 0 to make testpilot happy.
if exitstatus == 5:
session.exitstatus = 0


@pytest.fixture(autouse=True)
def prevent_leaking_rng():
Comment on lines +87 to +88
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pretty much the same as freeze_rng_state() https://github.com/pytorch/vision/blob/main/test/common_utils.py#L86

autouse=True means that this fixture will be used by every single test automatically.

# Prevent each test from leaking the rng to all other test when they call
# torch.manual_seed() or random.seed() or np.random.seed().
# Note: the numpy rngs should never leak anyway, as we never use
# np.random.seed() and instead rely on np.random.RandomState instances (see
# issue #4247). We still do it for extra precaution.

torch_rng_state = torch.get_rng_state()
builtin_rng_state = random.getstate()
nunmpy_rng_state = np.random.get_state()
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state()

yield

torch.set_rng_state(torch_rng_state)
random.setstate(builtin_rng_state)
np.random.set_state(nunmpy_rng_state)
if torch.cuda.is_available():
torch.cuda.set_rng_state(cuda_rng_state)
1 change: 1 addition & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,7 @@ def test_random_apply(device):
@pytest.mark.parametrize('channels', [1, 3])
def test_gaussian_blur(device, channels, meth_kwargs):
tol = 1.0 + 1e-10
torch.manual_seed(12)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to add this to prevent the test from failing (see https://app.circleci.com/pipelines/github/pytorch/vision/10830/workflows/946973f0-bad1-48ed-aaf3-f3720ab5f56a/jobs/828462).

If anything, this shows that the PR is working as expected and that this test is a bit flaky, and sensible to the _create_data() call.

To confirm I parametrized it over 100 random seeds, and I got 8 failures over 1000+ test instances. Each time, just 1 pixel was off. Considering the low failure rate, I think it's fine to keep the manual seeding here.

_test_class_op(
T.GaussianBlur, meth_kwargs=meth_kwargs, channels=channels,
test_exact_match=False, device=device, agg_method="max", tol=tol
Expand Down