Skip to content

Simplify the assertExpected method #2965

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 5, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 31 additions & 40 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,7 @@ def is_iterable(obj):
class TestCase(unittest.TestCase):
precision = 1e-5

def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
r"""
Test that a python value matches the recorded contents of a file
derived from the name of this test and subname. The value must be
pickable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using --accept.

If you call this multiple times in a single function, you must
give a unique subname each time.

strip_suffix allows different tests that expect similar numerics, e.g.
"test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
strip_suffix="_cpu", and they would both use a data file name based on
"test_xyz".
"""
def _get_expected_file(self, subname=None, strip_suffix=None):
def remove_prefix_suffix(text, prefix, suffix):
if text.startswith(prefix):
text = text[len(prefix):]
Expand All @@ -128,33 +111,41 @@ def remove_prefix_suffix(text, prefix, suffix):
subname_output = " ({})".format(subname)
expected_file += "_expect.pkl"

def accept_output(update_type):
print("Accepting {} for {}{}:\n\n{}".format(update_type, munged_id, subname_output, output))
if not ACCEPT and not os.path.exists(expected_file):
raise RuntimeError(
("No expect file exists for {}{}; to accept the current output, run:\n"
"python {} {} --accept").format(munged_id, subname_output, __main__.__file__, munged_id))

return expected_file

def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
r"""
Test that a python value matches the recorded contents of a file
derived from the name of this test and subname. The value must be
pickable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using --accept.

If you call this multiple times in a single function, you must
give a unique subname each time.

strip_suffix allows different tests that expect similar numerics, e.g.
"test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
strip_suffix="_cpu", and they would both use a data file name based on
"test_xyz".
"""
expected_file = self._get_expected_file(subname, strip_suffix)

if ACCEPT:
print("Accepting updated output for {}:\n\n{}".format(os.path.basename(expected_file), output))
torch.save(output, expected_file)
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
binary_size = os.path.getsize(expected_file)
self.assertTrue(binary_size <= MAX_PICKLE_SIZE)

try:
expected = torch.load(expected_file)
except IOError as e:
if e.errno != errno.ENOENT:
raise
elif ACCEPT:
accept_output("output")
return
else:
raise RuntimeError(
("I got this output for {}{}:\n\n{}\n\n"
"No expect file exists; to accept the current output, run:\n"
"python {} {} --accept").format(munged_id, subname_output, output, __main__.__file__, munged_id))

if ACCEPT:
try:
self.assertEqual(output, expected, prec=prec)
except Exception:
accept_output("updated output")
else:
expected = torch.load(expected_file)
self.assertEqual(output, expected, prec=prec)

def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
Expand Down