diff --git a/test/test_utils.py b/test/test_utils.py index 3132e90ac87..21e2ab461d7 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,3 +1,4 @@ +import numpy as np import os import sys import tempfile @@ -6,7 +7,6 @@ import unittest from io import BytesIO import torchvision.transforms.functional as F -from torchvision.io.image import read_image, write_png from PIL import Image @@ -90,9 +90,10 @@ def test_draw_boxes(self): path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png") if not os.path.exists(path): - write_png(result, path) + res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) + res.save(path) - expected = read_image(path) + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) self.assertTrue(torch.equal(result, expected))