
Commit fdca307

vfdev-5 and nairbv authored
Added CPU/CUDA and batch input for dtype conversion op (#2755)
* make convert_image_dtype scriptable
* move convert dtype to functional_tensor since only works on tensors
* retain availability of convert_image_dtype in functional.py
* Update code and tests
* Replaced int by torch.dtype
* int -> torch.dtype and use F instead of F_t
* Update functional_tensor.py
* Added CPU/CUDA+batch tests
* Fixed tests according to review

Co-authored-by: Brian <[email protected]>
1 parent 3d0c779 commit fdca307
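
The diffs below add a scriptable dtype-conversion transform that works on CPU and CUDA tensors as well as batched inputs. As a minimal usage sketch of what this commit exercises (names are taken from the diff; the shapes, values, and device choice are illustrative, not part of the commit, and this assumes a torchvision build that includes the change):

import torch
import torchvision.transforms as T
from torchvision.transforms import functional as F

# Script the class-based transform once; the new test does the same via torch.jit.script.
convert = T.ConvertImageDtype(dtype=torch.float32)
scripted_convert = torch.jit.script(convert)

# A single image and a batch of images; swap in device="cuda" to exercise the GPU path.
img = torch.randint(0, 256, (3, 26, 34), dtype=torch.uint8)
batch = torch.randint(0, 256, (4, 3, 26, 34), dtype=torch.uint8)

out = scripted_convert(img)          # uint8 -> float32, values rescaled to [0, 1]
out_batch = scripted_convert(batch)  # the same conversion applied to the whole batch

# The functional entry point is retained in functional.py, per the commit message.
out_f = F.convert_image_dtype(img, dtype=torch.float32)

Conversions that would lose information (for example float32 to int32 or int64) raise a RuntimeError containing "cannot be performed safely", which the new test in test_transforms_tensor.py asserts.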

File tree: 3 files changed, +48 −22 lines changed


test/common_utils.py

Lines changed: 13 additions & 0 deletions
@@ -369,3 +369,16 @@ def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_meth
             err < tol,
             msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
         )
+
+
+def cycle_over(objs):
+    for idx, obj in enumerate(objs):
+        yield obj, objs[:idx] + objs[idx + 1:]
+
+
+def int_dtypes():
+    return torch.testing.integral_types()
+
+
+def float_dtypes():
+    return torch.testing.floating_types()
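
The new helpers compose as follows; a quick sketch, assuming it is run from torchvision's test/ directory (so common_utils is importable) against a PyTorch build that provides torch.testing.integral_types() and floating_types():

from common_utils import cycle_over, int_dtypes, float_dtypes

# Both dtype helpers return tuple-like collections of torch.dtype, so they
# concatenate directly, just as the new test does with int_dtypes() + float_dtypes().
all_dtypes = int_dtypes() + float_dtypes()

# cycle_over pairs each dtype with all of the remaining ones, which suits
# "convert from X into every other dtype" style test loops.
for dtype, other_dtypes in cycle_over(all_dtypes):
    print(dtype, len(other_dtypes))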

test/test_transforms.py

Lines changed: 3 additions & 16 deletions
@@ -20,24 +20,11 @@
 except ImportError:
     stats = None
 
-GRACE_HOPPER = get_file_path_2(
-    os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
-
-
-def cycle_over(objs):
-    objs = list(objs)
-    for idx, obj in enumerate(objs):
-        yield obj, objs[:idx] + objs[idx + 1:]
+from common_utils import cycle_over, int_dtypes, float_dtypes
 
 
-def int_dtypes():
-    yield from iter(
-        (torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,)
-    )
-
-
-def float_dtypes():
-    yield from iter((torch.float32, torch.float, torch.float64, torch.double))
+GRACE_HOPPER = get_file_path_2(
+    os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
 
 
 class Tester(unittest.TestCase):

test/test_transforms_tensor.py

Lines changed: 32 additions & 6 deletions
@@ -9,7 +9,7 @@
 
 import unittest
 
-from common_utils import TransformsTester, get_tmp_dir
+from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
 
 
 class Tester(TransformsTester):
@@ -27,26 +27,26 @@ def _test_functional_op(self, func, fn_kwargs):
         transformed_pil_img = f(pil_img, **fn_kwargs)
         self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
 
-    def _test_transform_vs_scripted(self, transform, s_transform, tensor):
+    def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
         torch.manual_seed(12)
         out1 = transform(tensor)
         torch.manual_seed(12)
         out2 = s_transform(tensor)
-        self.assertTrue(out1.equal(out2))
+        self.assertTrue(out1.equal(out2), msg=msg)
 
-    def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors):
+    def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None):
         torch.manual_seed(12)
         transformed_batch = transform(batch_tensors)
 
         for i in range(len(batch_tensors)):
             img_tensor = batch_tensors[i, ...]
             torch.manual_seed(12)
             transformed_img = transform(img_tensor)
-            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))
+            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]), msg=msg)
 
         torch.manual_seed(12)
         s_transformed_batch = s_transform(batch_tensors)
-        self.assertTrue(transformed_batch.equal(s_transformed_batch))
+        self.assertTrue(transformed_batch.equal(s_transformed_batch), msg=msg)
 
     def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
         if meth_kwargs is None:
@@ -492,6 +492,32 @@ def test_random_erasing(self):
         self._test_transform_vs_scripted(fn, scripted_fn, tensor)
         self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
 
+    def test_convert_image_dtype(self):
+        tensor, _ = self._create_data(26, 34, device=self.device)
+        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
+
+        for in_dtype in int_dtypes() + float_dtypes():
+            in_tensor = tensor.to(in_dtype)
+            in_batch_tensors = batch_tensors.to(in_dtype)
+            for out_dtype in int_dtypes() + float_dtypes():
+
+                fn = T.ConvertImageDtype(dtype=out_dtype)
+                scripted_fn = torch.jit.script(fn)
+
+                if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \
+                        (in_dtype == torch.float64 and out_dtype == torch.int64):
+                    with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
+                        self._test_transform_vs_scripted(fn, scripted_fn, in_tensor)
+                    with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
+                        self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
+                    continue
+
+                self._test_transform_vs_scripted(fn, scripted_fn, in_tensor)
+                self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
+
+        with get_tmp_dir() as tmp_dir:
+            scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))
+
 
 @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
 class CUDATester(Tester):
